{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Reference Implementation for NFS - GMM-40 Benchmark\n",
        "\n",
        "This Jupyter Notebook provides a reference implementation of our proposed method, specifically tailored for the GMM-40 benchmark. It is designed to be self-contained and reproducible.\n",
        "\n",
        "**Important Note on Dependencies:** We utilize `uv` for dependency management. The provided installation commands will install dependencies to your system's Python environment. It is **strongly recommended** to use a virtual environment to avoid conflicts with existing packages.\n",
        "\n",
        "## Getting Started\n",
        "\n",
        "### Prerequisites\n",
        "* Python (version >= 3.12 recommended)\n",
        "\n",
        "### Installation\n",
        "\n",
        "1.  **Set up a Virtual Environment (Recommended):**\n",
        "    Before proceeding, create and activate a virtual environment:\n",
        "    ```bash\n",
        "    python -m venv .venv\n",
        "    source .venv/bin/activate  # On Windows, use: .venv\\Scripts\\activate\n",
        "    ```\n",
        "\n",
        "2.  **Install `uv`:**\n",
        "    `uv` is used for fast dependency resolution and installation.\n",
        "    ```python\n",
        "    !pip install uv\n",
        "    ```\n",
        "\n",
        "3.  **Install Dependencies using `uv`:**\n",
        "    The following command will install all necessary libraries into your current Python environment (ideally, your activated virtual environment).\n",
        "    ```python\n",
        "    !uv pip install jax jmp chex equinox optax matplotlib jaxtyping diffrax distrax POT blackjax\n",
        "    ```\n",
        "    * **JAX & Equinox:** For high-performance numerical computation and neural network models.\n",
        "    * **Optax:** For gradient processing and optimization.\n",
        "    * **Diffrax:** For numerical differential equation solvers.\n",
        "    * **Distrax:** For probabilistic modeling.\n",
        "    * **POT (Python Optimal Transport):** For Wasserstein distance calculations.\n",
        "    * **BlackJAX:** For MCMC algorithms like HMC.\n",
        "    * **Matplotlib:** For plotting.\n",
        "    * **Chex & Jaxtyping:** For type checking and assertions in JAX.\n",
        "\n",
        "\n",
        "## Configuration\n",
        "\n",
        "The behavior of the training and evaluation is controlled by the `TrainingExperimentConfig` object and its nested dataclasses (e.g., `ModelConfig`, `TrainingConfig`, `MCMCConfig`). You can modify the parameters in the \"Train and Eval\" section to experiment with different settings:\n",
        "* **`SamplingConfig`**: Number of particles and timesteps for sampling.\n",
        "* **`TrainingConfig`**: Epochs, learning rate, optimizer settings, batch sizes, shortcut parameters, etc.\n",
        "* **`MCMCConfig`**: MCMC method (VSMC, HMC, etc.), number of steps, step sizes.\n",
        "* **`IntegrationConfig`**: ODE solver method and time schedule.\n",
        "* **`ModelConfig`**: Neural network architecture details (hidden dimensions, layers).\n",
        "* **`DensityConfig`**: initial distribution parameters.\n",
        "\n",
        "Modify these configurations in the \"Train and Eval\" cell block to explore different experimental setups."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0hahbmwaCd3g"
      },
      "source": [
        "## Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "QbsI0dJdCSyR",
        "outputId": "f6ab2581-9f3f-43f8-b221-b47521707cb9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Collecting uv\n",
            "  Downloading uv-0.7.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n",
            "Downloading uv-0.7.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (17.4 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m17.4/17.4 MB\u001b[0m \u001b[31m68.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hInstalling collected packages: uv\n",
            "Successfully installed uv-0.7.6\n"
          ]
        }
      ],
      "source": [
        "!pip install uv"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 32,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Gb0YiZrBCcvs",
        "outputId": "059d24ac-70c6-4558-9cba-3f8bde8e4709"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "\u001b[2mUsing Python 3.11.12 environment at: /usr\u001b[0m\n",
            "\u001b[2K\u001b[2mResolved \u001b[1m41 packages\u001b[0m \u001b[2min 194ms\u001b[0m\u001b[0m\n",
            "\u001b[2K\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[1A\u001b[37m⠙\u001b[0m \u001b[2mPreparing packages...\u001b[0m (0/2)\n",
            "\u001b[2K\u001b[2mPrepared \u001b[1m2 packages\u001b[0m \u001b[2min 186ms\u001b[0m\u001b[0m\n",
            "\u001b[2K\u001b[2mInstalled \u001b[1m2 packages\u001b[0m \u001b[2min 6ms\u001b[0m\u001b[0m\n",
            " \u001b[32m+\u001b[39m \u001b[1mblackjax\u001b[0m\u001b[2m==1.2.5\u001b[0m\n",
            " \u001b[32m+\u001b[39m \u001b[1mjaxopt\u001b[0m\u001b[2m==0.8.5\u001b[0m\n"
          ]
        }
      ],
      "source": [
        "!uv pip install --system jax jmp chex equinox optax matplotlib jaxtyping diffrax distrax POT blackjax"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Aqgz2t_jCxLr"
      },
      "source": [
        "## Define Network Parameterisation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 68,
      "metadata": {
        "id": "aricH8SiCugU"
      },
      "outputs": [],
      "source": [
        "import equinox as eqx\n",
        "import jax\n",
        "import chex\n",
        "import jax.numpy as jnp\n",
        "from typing import Callable\n",
        "\n",
        "def init_linear_weights(\n",
        "    model: eqx.Module, init_fn: Callable, key: jax.random.PRNGKey, scale: float = 1.0\n",
        ") -> eqx.Module:\n",
        "    \"\"\"Initialize weights of all Linear layers in a model using the given initialization function.\"\"\"\n",
        "    is_linear = lambda x: isinstance(x, eqx.nn.Linear)\n",
        "    get_weights = lambda m: [\n",
        "        x.weight\n",
        "        for x in jax.tree_util.tree_leaves(m, is_leaf=is_linear)\n",
        "        if is_linear(x)\n",
        "    ]\n",
        "    weights = get_weights(model)\n",
        "    new_weights = [\n",
        "        init_fn(weight, subkey, scale)\n",
        "        for weight, subkey in zip(weights, jax.random.split(key, len(weights)))\n",
        "    ]\n",
        "    return eqx.tree_at(get_weights, model, new_weights)\n",
        "\n",
        "def xavier_init(\n",
        "    weight: jnp.ndarray, key: jax.random.PRNGKey, scale: float = 1.0\n",
        ") -> jnp.ndarray:\n",
        "    \"\"\"Xavier (Glorot) initialization.\"\"\"\n",
        "    out, in_ = weight.shape\n",
        "    bound = jnp.sqrt(6 / (in_ + out))\n",
        "    return scale * jax.random.uniform(\n",
        "        key, shape=(out, in_), minval=-bound, maxval=bound\n",
        "    )\n",
        "\n",
        "\n",
        "class AdaptiveFeatureProjection(eqx.Module):\n",
        "    \"\"\"Conditioning module for time and distance\"\"\"\n",
        "\n",
        "    time_mlp: eqx.nn.MLP\n",
        "    dist_mlp: eqx.nn.MLP\n",
        "    transform: eqx.nn.Linear\n",
        "    activation: eqx.nn.Lambda\n",
        "\n",
        "    def __init__(self, dim, key, activation=jax.nn.silu):\n",
        "        t_key, d_key, proj_key = jax.random.split(key, 3)\n",
        "        self.time_mlp = eqx.nn.MLP(1, dim, dim, 2, key=t_key)\n",
        "        self.dist_mlp = eqx.nn.MLP(1, dim, dim, 2, key=d_key)\n",
        "        self.transform = eqx.nn.Linear(dim * 2, dim, key=proj_key)\n",
        "        self.activation = eqx.nn.Lambda(activation)\n",
        "\n",
        "    def __call__(self, t: chex.Array, d: chex.Array):\n",
        "        t_feat = self.activation(self.time_mlp(t))\n",
        "        d_feat = self.activation(self.dist_mlp(d))\n",
        "\n",
        "        concat = jnp.concatenate([t_feat, d_feat], axis=-1)\n",
        "        return self.transform(concat)\n",
        "\n",
        "\n",
        "class VelocityFieldTwo(eqx.Module):\n",
        "    input_proj: eqx.nn.Linear\n",
        "    blocks: list\n",
        "    norm: eqx.nn.LayerNorm\n",
        "    output_proj: eqx.nn.Linear\n",
        "    conditioning: AdaptiveFeatureProjection\n",
        "    shortcut: bool\n",
        "    dt: float\n",
        "\n",
        "    def __init__(self, key, dim, hidden_dim, depth=6, shortcut=False, dt=0.01):\n",
        "        keys = jax.random.split(key, 6)\n",
        "        self.shortcut = shortcut\n",
        "        self.dt = dt\n",
        "\n",
        "        # Input processing\n",
        "        in_dim = dim + 2 if shortcut else dim + 1\n",
        "        self.input_proj = eqx.nn.Linear(in_dim, hidden_dim, key=keys[0])\n",
        "\n",
        "        # Residual blocks\n",
        "        self.blocks = [\n",
        "            eqx.nn.Sequential(\n",
        "                [\n",
        "                    eqx.nn.Linear(hidden_dim, hidden_dim, key=k),\n",
        "                    eqx.nn.LayerNorm(hidden_dim),\n",
        "                    eqx.nn.Lambda(jax.nn.gelu),\n",
        "                ]\n",
        "            )\n",
        "            for k in jax.random.split(keys[1], depth)\n",
        "        ]\n",
        "\n",
        "        # Conditioning system\n",
        "        self.conditioning = AdaptiveFeatureProjection(hidden_dim, keys[2])\n",
        "\n",
        "        # Output projection\n",
        "        self.norm = eqx.nn.LayerNorm(hidden_dim)\n",
        "        self.output_proj = eqx.nn.Linear(hidden_dim, dim, key=keys[3])\n",
        "        self._init_weights(keys[4])\n",
        "\n",
        "    def _init_weights(self, key):\n",
        "        \"\"\"Ensure stable initialization with dt scaling\"\"\"\n",
        "        self.output_proj = init_linear_weights(\n",
        "            self.output_proj, xavier_init, key, scale=self.dt\n",
        "        )\n",
        "\n",
        "    def __call__(self, x: chex.Array, t: float, d: float = None):\n",
        "        if d is not None and isinstance(d, float):\n",
        "            d = jnp.array([d])\n",
        "        if isinstance(t, float):\n",
        "            t = jnp.array([t])\n",
        "\n",
        "        if d is not None:\n",
        "            d = d.reshape(1)\n",
        "\n",
        "        t = t.reshape(1)\n",
        "        # Prepare inputs\n",
        "        if self.shortcut:\n",
        "            inputs = jnp.concatenate([x, t, d])\n",
        "        else:\n",
        "            inputs = jnp.concatenate([x, t])\n",
        "\n",
        "        # Project to hidden space\n",
        "        h = self.input_proj(inputs)\n",
        "\n",
        "        # Get conditioning features\n",
        "        cond = self.conditioning(t, d) if self.shortcut else 0.0\n",
        "\n",
        "        # Process through blocks\n",
        "        for block in self.blocks:\n",
        "            h = block(h + cond)  # Additive conditioning\n",
        "\n",
        "        # Final projection\n",
        "        return self.output_proj(self.norm(h))"
      ]
    },
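    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a standalone illustration of the initialization used above, here is a minimal sketch of scaled Xavier-uniform initialization. The helper name `scaled_xavier_uniform` and the layer sizes are assumptions made for this example and are not part of the notebook's code; the bound `sqrt(6 / (fan_in + fan_out))` matches `xavier_init`.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def scaled_xavier_uniform(key, out_features, in_features, scale=1.0):\n",
        "    # Glorot/Xavier uniform bound for a weight of shape (out, in)\n",
        "    bound = jnp.sqrt(6.0 / (in_features + out_features))\n",
        "    w = jax.random.uniform(\n",
        "        key, shape=(out_features, in_features), minval=-bound, maxval=bound\n",
        "    )\n",
        "    # Return the scaled weight together with the unscaled bound\n",
        "    return scale * w, bound\n",
        "\n",
        "\n",
        "key = jax.random.PRNGKey(0)\n",
        "w, bound = scaled_xavier_uniform(key, out_features=2, in_features=64, scale=0.01)\n",
        "\n",
        "# Largest absolute entry; it cannot exceed scale * bound\n",
        "max_abs = jnp.max(jnp.abs(w))\n",
        "```\n",
        "\n",
        "Scaling the output layer by a small factor such as `dt` keeps the initial velocity close to zero, which is presumably the intent of `_init_weights`."
      ]
    },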
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i5vr2YfHDrJj"
      },
      "source": [
        "## Define Utility Functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 69,
      "metadata": {
        "id": "hGBd2LKfEDUO"
      },
      "outputs": [],
      "source": [
        "from typing import Optional\n",
        "from functools import partial\n",
        "\n",
        "def _exact_logp_wrapper_with_shortcut(\n",
        "    t: float,\n",
        "    state: Tuple[Float[Array, \"batch dim\"], Float[Array, \"batch\"]],\n",
        "    args: Tuple,\n",
        "    forward: bool = True,\n",
        ") -> Tuple[Tuple, Float[Array, \"batch\"]]:\n",
        "    \"\"\"Exact divergence computation with shortcut support.\"\"\"\n",
        "    y, logp = state\n",
        "    eps, func, dt_abs = args\n",
        "\n",
        "    def fn(y):\n",
        "        return func(y, t, dt_abs)\n",
        "\n",
        "    trace = jnp.trace(jax.jacfwd(fn)(y))\n",
        "    f = fn(y)\n",
        "\n",
        "    return f, -trace if forward else trace\n",
        "\n",
        "def _exact_logp_wrapper(t, state, args, forward: bool = True):\n",
        "    y, logp = state\n",
        "    eps, func = args\n",
        "\n",
        "    def fn(y):\n",
        "        return func(y, t)  # No dt for non-shortcut\n",
        "\n",
        "    trace = jnp.trace(jax.jacfwd(fn)(y))\n",
        "    f = fn(y)\n",
        "\n",
        "    return f, -trace if forward else trace\n",
        "\n",
        "@eqx.filter_jit\n",
        "def solve_neural_ode_diffrax(\n",
        "    v_theta: Callable,\n",
        "    y0: Float[Array, \"batch dim\"],\n",
        "    ts: Float[Array, \"steps\"],\n",
        "    log_p0: Optional[Float[Array, \"batch\"]] = None,\n",
        "    use_shortcut: bool = False,\n",
        "    key: jax.random.PRNGKey = None,\n",
        "    forward: bool = True,\n",
        "    save_trajectory: bool = False,\n",
        "    solver: Optional[diffrax.AbstractSolver] = None,\n",
        ") -> Tuple[Float[Array, \"batch dim\"], Float[Array, \"batch\"]]:\n",
        "    \"\"\"\n",
        "    Solve the neural ODE using Diffrax.\n",
        "\n",
        "    Args:\n",
        "        v_theta: Velocity field accepting (t, y, dt) when use_shortcut=True\n",
        "        final_samples: Target distribution samples [batch, dim]\n",
        "        final_time: End time of forward process\n",
        "        dt: Time step for integration\n",
        "        use_shortcut: Whether velocity field uses time step dt\n",
        "        exact_logp: Compute exact divergence vs Hutchinson's estimator\n",
        "        key: PRNG key for Hutchinson's estimator\n",
        "\n",
        "    Returns:\n",
        "        (base_samples, log_probs)\n",
        "    \"\"\"\n",
        "    t0 = ts[0]\n",
        "    t1 = ts[-1]\n",
        "    dt = ts[1] - ts[0]\n",
        "\n",
        "    # Prepare augmented state (samples + log_probs)\n",
        "    if log_p0 is None:\n",
        "        initial_log_probs = jnp.zeros((y0.shape[0],))\n",
        "    else:\n",
        "        initial_log_probs = log_p0\n",
        "\n",
        "    augmented_state = (y0, initial_log_probs)\n",
        "\n",
        "    # Configure solver and step controller\n",
        "    if solver is None:\n",
        "        solver = diffrax.Euler()\n",
        "\n",
        "    # Prepare arguments based on computation mode\n",
        "    term = diffrax.ODETerm(\n",
        "        partial(_exact_logp_wrapper_with_shortcut, forward=forward)\n",
        "        if use_shortcut\n",
        "        else partial(_exact_logp_wrapper, forward=forward)\n",
        "    )\n",
        "    eps = jnp.zeros_like(y0)  # Dummy\n",
        "\n",
        "    # Special handling for shortcut: precompute absolute time steps\n",
        "    if use_shortcut:\n",
        "        args = (eps, v_theta, jnp.abs(dt))\n",
        "    else:\n",
        "        args = (eps, v_theta)\n",
        "\n",
        "    # Solve the reverse-time ODE\n",
        "    sols = jax.vmap(\n",
        "        lambda x: diffrax.diffeqsolve(\n",
        "            term,\n",
        "            solver,\n",
        "            t0=t0,\n",
        "            t1=t1,\n",
        "            dt0=None,\n",
        "            y0=x,\n",
        "            args=args,\n",
        "            saveat=diffrax.SaveAt(ts=ts)\n",
        "            if save_trajectory\n",
        "            else diffrax.SaveAt(t1=True),\n",
        "            stepsize_controller=diffrax.StepTo(ts=ts),\n",
        "        )\n",
        "    )(augmented_state)\n",
        "\n",
        "    # Extract final state and accumulated log probabilities\n",
        "    samples, log_probs = sols.ys\n",
        "    if save_trajectory:\n",
        "        return jnp.transpose(samples, axes=(1, 0, 2)), jnp.transpose(\n",
        "            log_probs, axes=(1, 0)\n",
        "        )\n",
        "    else:\n",
        "        return samples.reshape(y0.shape), log_probs.reshape(-1)"
      ]
    },
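    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The wrappers above augment the state with a log-density term that evolves as `d log p / dt = -div v`. The following is a minimal sketch of that idea using a plain Euler loop in JAX rather than Diffrax. The linear field `A`, the helper names `augmented_rhs` and `integrate`, and the time grid are assumptions chosen so the result can be checked by hand: for a linear field the divergence is `trace(A)` everywhere.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "A = jnp.array([[0.5, 0.2], [0.0, -0.1]])\n",
        "\n",
        "\n",
        "def velocity(x, t):\n",
        "    # Linear field; its divergence equals trace(A) at every point\n",
        "    return A @ x\n",
        "\n",
        "\n",
        "def augmented_rhs(x, t):\n",
        "    # Exact divergence as the trace of the forward-mode Jacobian\n",
        "    div = jnp.trace(jax.jacfwd(lambda y: velocity(y, t))(x))\n",
        "    return velocity(x, t), -div\n",
        "\n",
        "\n",
        "def integrate(x0, logp0, ts):\n",
        "    def step(carry, i):\n",
        "        x, logp = carry\n",
        "        dt = ts[i + 1] - ts[i]\n",
        "        dx, dlogp = augmented_rhs(x, ts[i])\n",
        "        # One explicit Euler update of both the sample and its log-density\n",
        "        return (x + dt * dx, logp + dt * dlogp), None\n",
        "\n",
        "    (x1, logp1), _ = jax.lax.scan(step, (x0, logp0), jnp.arange(ts.shape[0] - 1))\n",
        "    return x1, logp1\n",
        "\n",
        "\n",
        "ts = jnp.linspace(0.0, 1.0, 101)\n",
        "x1, logp1 = integrate(jnp.array([1.0, -1.0]), jnp.array(0.0), ts)\n",
        "\n",
        "# For a linear field the log-density changes by -trace(A) * (t1 - t0)\n",
        "expected = -jnp.trace(A) * (ts[-1] - ts[0])\n",
        "```\n",
        "\n",
        "Computing the full Jacobian costs one forward-mode pass per dimension, which is affordable in low dimensions such as the 2D benchmark considered here."
      ]
    },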
    {
      "cell_type": "code",
      "execution_count": 70,
      "metadata": {
        "id": "3_oT6ha-IupY"
      },
      "outputs": [],
      "source": [
        "@eqx.filter_jit\n",
        "def solve_neural_ode_euler(\n",
        "    v_theta: Callable,\n",
        "    y0: Float[Array, \"batch dim\"],\n",
        "    ts: Float[Array, \"steps\"],\n",
        "    use_shortcut: bool = False,\n",
        "    save_trajectory: bool = False,\n",
        "    **kwargs\n",
        ") -> Tuple[Union[Float[Array, \"batch dim\"], Float[Array, \"steps batch dim\"]],\n",
        "           Union[Float[Array, \"batch\"], Float[Array, \"steps batch\"]]]:\n",
        "    \"\"\"\n",
        "    A simple Euler integrator for neural ODEs as a drop-in replacement for solve_neural_ode_diffrax.\n",
        "    This implementation doesn't compute log probability updates but maintains API compatibility.\n",
        "\n",
        "    Args:\n",
        "        v_theta: Velocity field function\n",
        "        y0: Initial state [batch, dim]\n",
        "        ts: Time points for integration [steps]\n",
        "        use_shortcut: Whether velocity field uses time step dt\n",
        "        save_trajectory: Whether to save intermediate states\n",
        "\n",
        "    Returns:\n",
        "        samples: Final samples [batch, dim] or trajectory [steps, batch, dim]\n",
        "        log_probs: Unchanged log probabilities [batch] or replicated [steps, batch]\n",
        "    \"\"\"\n",
        "    batch_size, dim = y0.shape\n",
        "    num_timesteps = ts.shape[0]\n",
        "\n",
        "    # Define Euler step for a single sample\n",
        "    def euler_step(y, t, next_t):\n",
        "        dt = next_t - t\n",
        "        if use_shortcut:\n",
        "            dy = v_theta(y, t, jnp.abs(dt))\n",
        "        else:\n",
        "            dy = v_theta(y, t)\n",
        "        return y + dt * dy\n",
        "\n",
        "    # Vectorize over batch dimension\n",
        "    vmap_euler_step = jax.vmap(euler_step, in_axes=(0, None, None), out_axes=0)\n",
        "\n",
        "    # Define scan function\n",
        "    def scan_fn(y, t_idx):\n",
        "        t = ts[t_idx]\n",
        "        next_t = ts[t_idx + 1]\n",
        "        next_y = vmap_euler_step(y, t, next_t)\n",
        "        return next_y, next_y\n",
        "\n",
        "    if save_trajectory:\n",
        "        # Compute trajectory using scan\n",
        "        _, trajectory_without_y0 = jax.lax.scan(\n",
        "            scan_fn, y0, jnp.arange(num_timesteps - 1)\n",
        "        )\n",
        "\n",
        "        # Prepend y0 to get full trajectory\n",
        "        trajectory = jnp.concatenate([\n",
        "            jnp.expand_dims(y0, axis=0),\n",
        "            trajectory_without_y0\n",
        "        ], axis=0)\n",
        "\n",
        "        return trajectory, None\n",
        "    else:\n",
        "        # Just compute final state using scan\n",
        "        final_y, _ = jax.lax.scan(\n",
        "            scan_fn, y0, jnp.arange(num_timesteps - 1)\n",
        "        )\n",
        "\n",
        "        return final_y, None"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 71,
      "metadata": {
        "id": "UOsZsNrTDtLy"
      },
      "outputs": [],
      "source": [
        "from typing import Callable, Dict, Tuple, Union\n",
        "from jaxtyping import Array, Float, PRNGKeyArray\n",
        "import diffrax\n",
        "\n",
        "@eqx.filter_jit\n",
        "def generate_samples_with_log_prob(\n",
        "    v_theta: Callable[[Float[Array, \"dim\"], float], Float[Array, \"dim\"]],\n",
        "    initial_samples: Float[Array, \"num_samples dim\"],\n",
        "    initial_log_probs: Float[Array, \"num_samples\"],\n",
        "    ts: Float[Array, \"num_timesteps\"],\n",
        "    use_shortcut: bool = False,\n",
        "    solver: str = \"Euler\",\n",
        "    estimate_dt_logZ: bool = False,\n",
        "    **kwargs,\n",
        ") -> Tuple[\n",
        "    Float[Array, \"num_timesteps num_samples dim\"],\n",
        "    Float[Array, \"num_timesteps num_samples\"],\n",
        "]:\n",
        "    final_samples, final_log_probs = solve_neural_ode_diffrax(\n",
        "        v_theta=v_theta if not estimate_dt_logZ else lambda *args: v_theta(*args)[0],\n",
        "        y0=initial_samples,\n",
        "        ts=ts,\n",
        "        log_p0=initial_log_probs,\n",
        "        use_shortcut=use_shortcut,\n",
        "        forward=True,\n",
        "        solver=diffrax.Tsit5() if solver == \"Tsit5\" else diffrax.Euler(),\n",
        "    )\n",
        "    return final_samples, final_log_probs"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 72,
      "metadata": {
        "id": "MZRHT5NEEqE6"
      },
      "outputs": [],
      "source": [
        "import itertools\n",
        "from typing import Optional, Tuple, Callable\n",
        "\n",
        "import chex\n",
        "import jax.numpy as jnp\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def plot_contours(log_prob_func: Callable,\n",
        "                  ax: Optional[plt.Axes] = None,\n",
        "                  bounds: Tuple[float, float] = (-5.0, 5.0),\n",
        "                  grid_width_n_points: int = 200,\n",
        "                  n_contour_levels: Optional[int] = None,\n",
        "                  log_prob_min: float = -1000.0):\n",
        "    \"\"\"Plot contours of a log_prob_func that is defined on 2D\"\"\"\n",
        "    if ax is None:\n",
        "        fig, ax = plt.subplots(1)\n",
        "\n",
        "    # Create grid points using NumPy, as matplotlib handles NumPy arrays well\n",
        "    # and original plot_contours_2D also used NumPy for grid.\n",
        "    x_points_dim1_np = np.linspace(bounds[0], bounds[1], grid_width_n_points)\n",
        "    x_points_dim2_np = np.linspace(bounds[0], bounds[1], grid_width_n_points)\n",
        "\n",
        "    # Prepare input for log_prob_func, matching the structure from user's example\n",
        "    # (flat list of points)\n",
        "    x_points_for_func = np.array(list(itertools.product(x_points_dim1_np, x_points_dim2_np)))\n",
        "\n",
        "    # Call the log_prob_func. It might return JAX array or NumPy array.\n",
        "    # Assuming it can handle NumPy input based on previous plot_contours_2D.\n",
        "    log_p_x = log_prob_func(x_points_for_func)\n",
        "\n",
        "    # Ensure log_p_x is a JAX array for jnp.clip, then convert to NumPy for plotting\n",
        "    log_p_x_jnp = jnp.asarray(log_p_x)\n",
        "    log_p_x_clipped_jnp = jnp.clip(log_p_x_jnp, a_min=log_prob_min)\n",
        "    log_p_x_clipped_np = np.asarray(log_p_x_clipped_jnp)\n",
        "\n",
        "    log_p_x_reshaped_np = log_p_x_clipped_np.reshape((grid_width_n_points, grid_width_n_points))\n",
        "\n",
        "    # Prepare grid for contour plot (meshgrid approach)\n",
        "    # X_np, Y_np = np.meshgrid(x_points_dim1_np, x_points_dim2_np)\n",
        "    # Or, reshape individual coordinate arrays as in user's example:\n",
        "    x_coords_reshaped_np = x_points_for_func[:, 0].reshape((grid_width_n_points, grid_width_n_points))\n",
        "    y_coords_reshaped_np = x_points_for_func[:, 1].reshape((grid_width_n_points, grid_width_n_points))\n",
        "\n",
        "    if n_contour_levels is not None: # Check for None explicitly\n",
        "        ax.contour(x_coords_reshaped_np, y_coords_reshaped_np, log_p_x_reshaped_np, levels=n_contour_levels)\n",
        "    else:\n",
        "        ax.contour(x_coords_reshaped_np, y_coords_reshaped_np, log_p_x_reshaped_np)\n",
        "\n",
        "\n",
        "def plot_marginal_pair(\n",
        "    samples: chex.Array,\n",
        "    ax: Optional[plt.Axes] = None,\n",
        "    marginal_dims: Tuple[int, int] = (0, 1),\n",
        "    bounds: Tuple[float, float] = (-5.0, 5.0),\n",
        "    alpha: float = 0.5,\n",
        "    markersize: float = 1.5,  # Added markersize\n",
        "):\n",
        "    \"\"\"Plot samples from marginal of distribution for a given pair of dimensions.\"\"\"\n",
        "    if not ax:\n",
        "        fig, ax = plt.subplots(1)\n",
        "\n",
        "    # Clip samples using JAX\n",
        "    samples_clipped = jnp.clip(samples, bounds[0], bounds[1])\n",
        "\n",
        "    # Convert to NumPy for plotting, as matplotlib expects NumPy arrays or lists\n",
        "    samples_np = np.asarray(samples_clipped)\n",
        "\n",
        "    ax.plot(\n",
        "        samples_np[:, marginal_dims[0]],\n",
        "        samples_np[:, marginal_dims[1]],\n",
        "        \"o\",\n",
        "        alpha=alpha,\n",
        "        markersize=markersize  # Use markersize\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 73,
      "metadata": {
        "id": "ItuQ_jimFEjw"
      },
      "outputs": [],
      "source": [
        "import ot as pot\n",
        "\n",
        "def compute_w2_distance_1d_pot(x, y):\n",
        "    # return pot.wasserstein_1d(x, y, p=2.0)\n",
        "    return jnp.sqrt(pot.emd2_1d(x, y))\n",
        "\n",
        "\n",
        "def compute_w1_distance_1d_pot(x, y):\n",
        "    return pot.wasserstein_1d(x, y, p=1.0)\n",
        "\n",
        "\n",
        "def compute_total_variation_distance(\n",
        "    samples_p, samples_q, num_bins=200, lower_bound=-5.0, upper_bound=5.0\n",
        "):\n",
        "    \"\"\"\n",
        "    Compute the Total Variation (TV) distance between two distributions (p and q) in N-dimensional space.\n",
        "\n",
        "    Args:\n",
        "    samples_p: Array of shape (n_samples, d) representing samples from distribution P\n",
        "    samples_q: Array of shape (n_samples, d) representing samples from distribution Q\n",
        "    num_bins: Number of bins to use for histogram estimation per dimension\n",
        "    lower_bound: Lower bound of the sample space for each dimension. If None, computed from data.\n",
        "    upper_bound: Upper bound of the sample space for each dimension. If None, computed from data.\n",
        "\n",
        "    Returns:\n",
        "    TV distance: Scalar value representing the Total Variation distance\n",
        "    \"\"\"\n",
        "    # Ensure samples are on CPU before histogram calculation\n",
        "    cpu_device = jax.devices(\"cpu\")[0]\n",
        "    samples_p = jax.device_put(samples_p, cpu_device)\n",
        "    samples_q = jax.device_put(samples_q, cpu_device)\n",
        "\n",
        "    # Create bin edges for each dimension\n",
        "    bin_edges = [\n",
        "        jnp.linspace(lower_bound, upper_bound, num_bins + 1)\n",
        "        for _ in range(samples_p.shape[1])\n",
        "    ]\n",
        "\n",
        "    # Compute histograms for both distributions (normalized)\n",
        "    hist_p, _ = jnp.histogramdd(samples_p, bins=bin_edges, density=True)\n",
        "    hist_q, _ = jnp.histogramdd(samples_q, bins=bin_edges, density=True)\n",
        "\n",
        "    # Normalize histograms explicitly to ensure their sum is 1\n",
        "    hist_p /= jnp.sum(hist_p)\n",
        "    hist_q /= jnp.sum(hist_q)\n",
        "\n",
        "    # Compute Total Variation distance as the half sum of absolute differences\n",
        "    tv_distance = 0.5 * jnp.sum(jnp.abs(hist_p - hist_q))\n",
        "\n",
        "    return tv_distance\n",
        "\n",
        "\n",
        "compute_total_variation_distance = jax.jit(\n",
        "    compute_total_variation_distance, static_argnums=(2, 3, 4)\n",
        ")"
      ]
    },
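    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the histogram-based estimate above concrete, the following is a minimal standalone sketch with two sanity checks: identical sample sets should give a distance of zero, and sample sets with essentially disjoint support should give a distance close to one. The helper name `histogram_tv`, the bin count, and the synthetic Gaussian samples are assumptions made for this example.\n",
        "\n",
        "```python\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "\n",
        "\n",
        "def histogram_tv(samples_p, samples_q, num_bins=20, lower=-5.0, upper=5.0):\n",
        "    # Shared bin edges per dimension so both histograms are comparable\n",
        "    edges = [\n",
        "        jnp.linspace(lower, upper, num_bins + 1)\n",
        "        for _ in range(samples_p.shape[1])\n",
        "    ]\n",
        "    hist_p, _ = jnp.histogramdd(samples_p, bins=edges)\n",
        "    hist_q, _ = jnp.histogramdd(samples_q, bins=edges)\n",
        "    # Convert counts to probability masses before comparing\n",
        "    p = hist_p / jnp.sum(hist_p)\n",
        "    q = hist_q / jnp.sum(hist_q)\n",
        "    return 0.5 * jnp.sum(jnp.abs(p - q))\n",
        "\n",
        "\n",
        "key_a, key_b = jax.random.split(jax.random.PRNGKey(0))\n",
        "a = jax.random.normal(key_a, (2000, 2))\n",
        "# Tight cluster centred at (3, 3), far from most of a's mass\n",
        "b = jax.random.normal(key_b, (2000, 2)) * 0.1 + 3.0\n",
        "\n",
        "tv_same = histogram_tv(a, a)\n",
        "tv_far = histogram_tv(a, b)\n",
        "```\n",
        "\n",
        "The number of bins grows as `num_bins ** d`, so this estimator is practical only in low dimensions, and its value depends on the chosen bin width and bounds."
      ]
    },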
    {
      "cell_type": "code",
      "execution_count": 74,
      "metadata": {
        "id": "hkJVW9woH1T6"
      },
      "outputs": [],
      "source": [
        "class AnnealedDistribution:\n",
        "    TIME_DEPENDENT = True\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        initial_density,\n",
        "        target_density,\n",
        "        method: str = \"linear\",\n",
        "        prior_regularization: bool = False,\n",
        "    ):\n",
        "        self.dim = initial_density.dim\n",
        "        self.n_model_samples_eval = 1000\n",
        "        self.n_target_samples_eval = 1000\n",
        "        self.initial_density = initial_density\n",
        "        self.target_density = target_density\n",
        "        self.prior_regularization = prior_regularization\n",
        "        self.method = method\n",
        "\n",
        "    def log_prob(self, xs: chex.Array) -> chex.Array:\n",
        "        return self.time_dependent_log_prob(xs, 1.0)\n",
        "\n",
        "    def base_log_prob(self, xs: chex.Array) -> chex.Array:\n",
        "        return self.initial_density.log_prob(xs)\n",
        "\n",
        "    def time_dependent_log_prob(self, xs: chex.Array, t: chex.Array) -> chex.Array:\n",
        "        beta = t\n",
        "\n",
        "        if self.prior_regularization:\n",
        "            initial_prob = self.initial_density.log_prob(xs)\n",
        "        else:\n",
        "            initial_prob = (1 - beta) * self.initial_density.log_prob(xs)\n",
        "\n",
        "\n",
        "        if self.target_density.TIME_DEPENDENT:\n",
        "            target_prob = beta * self.target_density.time_dependent_log_prob(xs, t)\n",
        "        else:\n",
        "            target_prob = beta * self.target_density.log_prob(xs)\n",
        "\n",
        "        return initial_prob + target_prob\n",
        "\n",
        "    def incremental_log_delta(self, xs: chex.Array, dt: float) -> chex.Array:\n",
        "        return dt * (\n",
        "            self.target_density.log_prob(xs) - self.initial_density.log_prob(xs)\n",
        "        )\n",
        "\n",
        "    def time_derivative(self, xs: chex.Array, t: float) -> chex.Array:\n",
        "        return jax.grad(lambda t: self.time_dependent_log_prob(xs, t))(t)\n",
        "\n",
        "    def score_fn(self, xs: chex.Array, t: float) -> chex.Array:\n",
        "        return jax.grad(lambda x: self.time_dependent_log_prob(x, t))(xs)\n",
        "\n",
        "    def sample_initial(self, key: chex.PRNGKey, sample_shape: chex.Shape) -> chex.Array:\n",
        "        return self.initial_density.sample(key, sample_shape)\n"
      ]
    },
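    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a reading of `AnnealedDistribution.time_dependent_log_prob` above (the symbols $p_0$ and $\\tilde{p}_1$ for the initial and the unnormalised target density are notation introduced here, not names from the code), the default annealing path is\n",
        "\n",
        "$$\n",
        "\\log \\tilde{p}_t(x) = (1 - t)\\,\\log p_0(x) + t\\,\\log \\tilde{p}_1(x), \\qquad t \\in [0, 1].\n",
        "$$\n",
        "\n",
        "With `prior_regularization=True` the factor $(1 - t)$ is dropped, so the path becomes $\\log p_0(x) + t\\,\\log \\tilde{p}_1(x)$. For a time-independent target, the default path has the time derivative\n",
        "\n",
        "$$\n",
        "\\partial_t \\log \\tilde{p}_t(x) = \\log \\tilde{p}_1(x) - \\log p_0(x),\n",
        "$$\n",
        "\n",
        "which is the quantity that `incremental_log_delta` multiplies by `dt`."
      ]
    },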
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DYXDBWCyDUqy"
      },
      "source": [
        "## Define Target Density"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 75,
      "metadata": {
        "id": "vFcUd1YRQAV9"
      },
      "outputs": [],
      "source": [
        "class MultivariateGaussian:\n",
        "    TIME_DEPENDENT = False\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        dim: int = 2,\n",
        "        mean: float = 0.0,\n",
        "        sigma: float = 1.0,\n",
        "        n_samples_eval: int = 1024,\n",
        "    ):\n",
        "        self.dim = dim\n",
        "        self.n_model_samples_eval = n_samples_eval\n",
        "        self.n_target_samples_eval = n_samples_eval\n",
        "        self.sigma = jnp.asarray(sigma)\n",
        "        if self.sigma.ndim == 0:\n",
        "            scale_diag = jnp.full((dim,), self.sigma)\n",
        "        else:\n",
        "            if self.sigma.shape[0] != dim:\n",
        "                raise ValueError(\n",
        "                    f\"Sigma shape {self.sigma.shape} does not match dimension {dim}.\"\n",
        "                )\n",
        "            scale_diag = self.sigma\n",
        "\n",
        "        self.mean = jnp.ones((dim,)) * mean\n",
        "\n",
        "        self.distribution = distrax.MultivariateNormalDiag(\n",
        "            loc=self.mean, scale_diag=scale_diag\n",
        "        )\n",
        "\n",
        "    def log_prob(self, x: chex.Array) -> chex.Array:\n",
        "        return self.distribution.log_prob(x)\n",
        "\n",
        "    def sample(self, seed: chex.PRNGKey, sample_shape: chex.Shape = ()) -> chex.Array:\n",
        "        return self.distribution.sample(seed=seed, sample_shape=sample_shape)\n",
        "\n",
        "    def visualise(self, samples: chex.Array) -> plt.Figure:\n",
        "        fig, ax = plt.subplots(1, figsize=(6, 6))\n",
        "        if self.dim == 2:\n",
        "            # Plot contour lines for the distribution\n",
        "            # Create a grid\n",
        "            grid_size = 100\n",
        "            x_lin = jnp.linspace(\n",
        "                self.mean[0] - 3 * self.sigma[0],\n",
        "                self.mean[0] + 3 * self.sigma[0],\n",
        "                grid_size,\n",
        "            )\n",
        "            y_lin = jnp.linspace(\n",
        "                self.mean[1] - 3 * self.sigma[1],\n",
        "                self.mean[1] + 3 * self.sigma[1],\n",
        "                grid_size,\n",
        "            )\n",
        "            X, Y = jnp.meshgrid(x_lin, y_lin)\n",
        "            grid = jnp.stack([X, Y], axis=-1).reshape(-1, 2)  # Shape: (grid_size**2, 2)\n",
        "\n",
        "            # Compute log_prob for each grid point\n",
        "            log_probs = self.log_prob(grid).reshape(grid_size, grid_size)\n",
        "\n",
        "            # Plot contours\n",
        "            ax.contour(X, Y, log_probs, levels=20, cmap=\"viridis\")\n",
        "            ax.set_xlim(\n",
        "                self.mean[0] - 3 * self.sigma[0], self.mean[0] + 3 * self.sigma[0]\n",
        "            )\n",
        "            ax.set_ylim(\n",
        "                self.mean[1] - 3 * self.sigma[1], self.mean[1] + 3 * self.sigma[1]\n",
        "            )\n",
        "\n",
        "            # Overlay scatter plot of samples\n",
        "            ax.scatter(\n",
        "                samples[:, 0],\n",
        "                samples[:, 1],\n",
        "                alpha=0.5,\n",
        "                s=10,\n",
        "                color=\"red\",\n",
        "                label=\"Samples\",\n",
        "            )\n",
        "\n",
        "            ax.set_title(\"Multivariate Gaussian (2D)\")\n",
        "            ax.set_xlabel(\"Dimension 1\")\n",
        "            ax.set_ylabel(\"Dimension 2\")\n",
        "            ax.legend()\n",
        "            ax.grid(True)\n",
        "\n",
        "        return fig\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 76,
      "metadata": {
        "id": "j1peRLSuDW9-"
      },
      "outputs": [],
      "source": [
        "import distrax\n",
        "import matplotlib.pyplot as plt\n",
        "from typing import Callable, Dict\n",
        "\n",
        "class GMM:\n",
        "    TIME_DEPENDENT = False\n",
        "    TARGET_METRIC = ((\"total_variation\", True),)\n",
        "\n",
        "    def __init__(\n",
        "        self,\n",
        "        key: chex.PRNGKey,\n",
        "        dim: int = 2,\n",
        "        n_mixes: int = 40,\n",
        "        loc_scaling: float = 40.0,\n",
        "        scale_scaling: float = 1.0,\n",
        "        fixed_mean: bool = True,\n",
        "        n_samples_eval: int = 1024,\n",
        "    ) -> None:\n",
        "        self.n_mixes = n_mixes\n",
        "        self.n_model_samples_eval = n_samples_eval\n",
        "        self.n_target_samples_eval = n_samples_eval\n",
        "        self.dim = dim\n",
        "\n",
        "        logits = jnp.ones(n_mixes)\n",
        "        if fixed_mean:\n",
        "            mean = jnp.array(\n",
        "                [\n",
        "                    [-0.2995, 21.4577],\n",
        "                    [-32.9218, -29.4376],\n",
        "                    [-15.4062, 10.7263],\n",
        "                    [-0.7925, 31.7156],\n",
        "                    [-3.5498, 10.5845],\n",
        "                    [-12.0885, -7.8626],\n",
        "                    [-38.2139, -26.4913],\n",
        "                    [-16.4889, 1.4817],\n",
        "                    [15.8134, 24.0009],\n",
        "                    [-27.1176, -17.4185],\n",
        "                    [14.5287, 33.2155],\n",
        "                    [-8.2320, 29.9325],\n",
        "                    [-6.4473, 4.2326],\n",
        "                    [36.2190, -37.1068],\n",
        "                    [-25.1815, -10.1266],\n",
        "                    [-15.5920, 34.5600],\n",
        "                    [-25.9272, -18.4133],\n",
        "                    [-27.9456, -37.4624],\n",
        "                    [-23.3496, 34.3839],\n",
        "                    [17.8487, 19.3869],\n",
        "                    [2.1037, -20.5073],\n",
        "                    [6.7674, -37.3478],\n",
        "                    [-28.9026, -20.6212],\n",
        "                    [25.2375, 23.4529],\n",
        "                    [-17.7398, -1.4433],\n",
        "                    [25.5824, 39.7653],\n",
        "                    [15.8753, 5.4037],\n",
        "                    [26.8195, -23.5521],\n",
        "                    [7.4538, -31.0122],\n",
        "                    [-27.7234, -20.6633],\n",
        "                    [18.0989, 16.0864],\n",
        "                    [-23.6941, 12.0843],\n",
        "                    [21.9589, -5.0487],\n",
        "                    [1.5273, 9.2682],\n",
        "                    [24.8151, 38.4078],\n",
        "                    [-30.8249, -14.6588],\n",
        "                    [15.7204, 33.1420],\n",
        "                    [34.8083, 35.2943],\n",
        "                    [7.9606, -34.7833],\n",
        "                    [3.6797, -25.0242],\n",
        "                ]\n",
        "            ) * (loc_scaling / 40.0)\n",
        "        else:\n",
        "            mean = jax.random.normal(key, shape=(n_mixes, dim)) * loc_scaling\n",
        "\n",
        "        scale = jnp.ones(shape=(n_mixes, dim)) * scale_scaling\n",
        "\n",
        "        mixture_dist = distrax.Categorical(logits=logits)\n",
        "        components_dist = distrax.Independent(\n",
        "            distrax.Normal(loc=mean, scale=scale), reinterpreted_batch_ndims=1\n",
        "        )\n",
        "        self.distribution = distrax.MixtureSameFamily(\n",
        "            mixture_distribution=mixture_dist,\n",
        "            components_distribution=components_dist,\n",
        "        )\n",
        "\n",
        "        self._plot_bound = loc_scaling * 1.5\n",
        "\n",
        "    def log_prob(self, x: chex.Array) -> chex.Array:\n",
        "        log_prob = self.distribution.log_prob(x)\n",
        "        return log_prob\n",
        "\n",
        "    def sample(self, seed: chex.PRNGKey, sample_shape: chex.Shape = ()) -> chex.Array:\n",
        "        return self.distribution.sample(seed=seed, sample_shape=sample_shape)\n",
        "\n",
        "    def visualise(\n",
        "        self,\n",
        "        samples: chex.Array,\n",
        "    ) -> plt.Figure:\n",
        "        \"\"\"Visualise samples from the model.\"\"\"\n",
        "        fig, ax = plt.subplots(1, figsize=(4, 3.6))  # Updated figsize\n",
        "        plot_marginal_pair(\n",
        "            samples, ax, bounds=(-self._plot_bound, self._plot_bound), markersize=1.5\n",
        "        )  # Explicitly set markersize for clarity, though it's default\n",
        "        if self.dim == 2:\n",
        "            plot_contours(\n",
        "                self.log_prob,\n",
        "                ax=ax,\n",
        "                bounds=(-self._plot_bound, self._plot_bound),\n",
        "                n_contour_levels=50,\n",
        "                grid_width_n_points=200,\n",
        "            )\n",
        "\n",
        "        plt.axis(\"off\")  # Updated axis styling\n",
        "\n",
        "        return fig\n",
        "\n",
        "    def evaluate(\n",
        "        self,\n",
        "        key: chex.PRNGKey,\n",
        "        *,\n",
        "        v_theta: Callable,\n",
        "        ts: chex.Array,\n",
        "        base_density,\n",
        "        use_shortcut: bool = False,\n",
        "        estimate_dt_logZ: bool = False,\n",
        "        learnable_path: bool = False,\n",
        "        **kwargs,\n",
        "    ) -> Dict[str, float]:\n",
        "        metrics = {}\n",
        "\n",
        "        key, sample_key = jax.random.split(key)\n",
        "        initial_samples = base_density.sample(\n",
        "            sample_key, (self.n_model_samples_eval,)\n",
        "        )  # Sample from base distribution q_0\n",
        "        initial_log_probs = base_density.log_prob(initial_samples)\n",
        "\n",
        "        samples_q, samples_log_q = generate_samples_with_log_prob(\n",
        "            v_theta=v_theta,\n",
        "            initial_samples=initial_samples,\n",
        "            initial_log_probs=initial_log_probs,\n",
        "            ts=ts,\n",
        "            use_shortcut=use_shortcut,\n",
        "            estimate_dt_logZ=estimate_dt_logZ,\n",
        "        )\n",
        "\n",
        "        # --- Ensure all relevant arrays are on CPU before metric calculation ---\n",
        "        cpu_device = jax.devices(\"cpu\")[0]\n",
        "        samples_q = jax.device_put(samples_q, cpu_device)\n",
        "        samples_log_q = jax.device_put(samples_log_q, cpu_device)  # Needed for ESS\n",
        "\n",
        "        if self.dim == 2:\n",
        "            metrics[\"figure\"] = self.visualise(samples_q)  # Visualise uses samples_q\n",
        "\n",
        "        key, sample_key = jax.random.split(key)\n",
        "        # Generate true samples on CPU directly if possible, or move them\n",
        "        true_samples = self.sample(sample_key, (self.n_model_samples_eval,))\n",
        "        true_samples = jax.device_put(true_samples, cpu_device)\n",
        "\n",
        "        # Compute log probs on CPU\n",
        "        log_prob_samples = jax.device_put(self.log_prob(samples_q), cpu_device)\n",
        "        log_prob_true_samples = jax.device_put(self.log_prob(true_samples), cpu_device)\n",
        "\n",
        "        # Compute energy distances with CPU arrays\n",
        "        e_w2_distance = compute_w2_distance_1d_pot(\n",
        "            log_prob_samples,\n",
        "            log_prob_true_samples,\n",
        "        )\n",
        "\n",
        "        e_w1_distance = compute_w1_distance_1d_pot(\n",
        "            log_prob_samples,\n",
        "            log_prob_true_samples,\n",
        "        )\n",
        "\n",
        "        # Compute total variation distances with CPU arrays\n",
        "        if self.dim == 2:\n",
        "            total_variation = compute_total_variation_distance(\n",
        "                samples_q,\n",
        "                true_samples,\n",
        "                num_bins=200,\n",
        "                lower_bound=-self._plot_bound,\n",
        "                upper_bound=self._plot_bound,\n",
        "            )\n",
        "            metrics[\"total_variation\"] = total_variation\n",
        "\n",
        "        energy_total_variation = compute_total_variation_distance(\n",
        "            log_prob_samples.reshape(-1, 1),\n",
        "            log_prob_true_samples.reshape(-1, 1),\n",
        "            num_bins=200,\n",
        "            lower_bound=-100,  # Adjust bounds if necessary\n",
        "            upper_bound=100,  # Adjust bounds if necessary\n",
        "        )\n",
        "        metrics[\"energy_total_variation\"] = energy_total_variation\n",
        "\n",
        "        metrics[\"e_w2_distance\"] = e_w2_distance\n",
        "        metrics[\"e_w1_distance\"] = e_w1_distance\n",
        "\n",
        "        return metrics\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OiIiw4mIFqeU"
      },
      "source": [
        "## Define Objective"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 119,
      "metadata": {
        "id": "-wXuWf8lF4F4"
      },
      "outputs": [],
      "source": [
        "class Particle(eqx.Module):\n",
        "    x: chex.Array\n",
        "    t: chex.Array\n",
        "    log_Z_t: chex.Array\n",
        "    d: Optional[chex.Array] = None\n",
        "    loss_weight: Optional[chex.Array] = None\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 120,
      "metadata": {
        "id": "Hjc-YOd_GFaA"
      },
      "outputs": [],
      "source": [
        "@eqx.filter_jit\n",
        "def divergence_velocity_with_shortcut(\n",
        "    v_theta: Callable[[chex.Array, float], chex.Array],\n",
        "    x: chex.Array,\n",
        "    t: float,\n",
        "    d: float,\n",
        ") -> chex.Array:\n",
        "    def v_x(x):\n",
        "        return v_theta(x, t, d)\n",
        "\n",
        "    return jnp.trace(jax.jacfwd(v_x)(x))\n",
        "\n",
        "\n",
        "@eqx.filter_jit\n",
        "def divergence_velocity(\n",
        "    v_theta: Callable[[chex.Array, float], chex.Array],\n",
        "    x: chex.Array,\n",
        "    t: float,\n",
        ") -> float:\n",
        "    def v_x(x):\n",
        "        return v_theta(x, t)\n",
        "\n",
        "    return jnp.trace(jax.jacfwd(v_x)(x))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 121,
      "metadata": {
        "id": "unjaWi9TGA4u"
      },
      "outputs": [],
      "source": [
        "@eqx.filter_jit\n",
        "@chex.assert_max_traces(4)\n",
        "def epsilon(\n",
        "    v_theta: Callable[[chex.Array, float, float], chex.Array]\n",
        "    | Callable[[chex.Array, float], chex.Array],\n",
        "    particle: Particle,\n",
        "    score_fn: Callable[[chex.Array, float], chex.Array],\n",
        "    time_derivative_log_density: Callable[[chex.Array, float], float],\n",
        ") -> chex.Array:\n",
        "    \"\"\"Computes the local error using a Particle instance.\"\"\"\n",
        "    x, t, log_Z_t, d = particle.x, particle.t, particle.log_Z_t, particle.d\n",
        "\n",
        "    # Get score vector\n",
        "    score = score_fn(x, t)\n",
        "\n",
        "    # Calculate divergence and velocity with appropriate precision\n",
        "    if d is not None:\n",
        "        div_v = divergence_velocity_with_shortcut(v_theta, x, t, d)\n",
        "        v = v_theta(x, t, d)\n",
        "    else:\n",
        "        div_v = divergence_velocity(v_theta, x, t)\n",
        "        v = v_theta(x, t)\n",
        "\n",
        "    dt_log_unormalised = time_derivative_log_density(x, t)\n",
        "    dt_log_density = dt_log_unormalised - log_Z_t\n",
        "\n",
        "    # Calculate dot product with better numerical stability\n",
        "    # If x and y are in low precision, cast to higher precision for the dot product\n",
        "    v_dot_score = jnp.sum(\n",
        "        v * score\n",
        "    )  # element-wise multiply then sum is more stable than dot product\n",
        "\n",
        "    # Calculate final result and handle NaN/inf values\n",
        "    lhs = div_v + v_dot_score\n",
        "    result = lhs + dt_log_density\n",
        "\n",
        "    # Ensure no NaN or inf values propagate\n",
        "    # chex.assert_tree_all_finite(result)\n",
        "    return jnp.nan_to_num(result, posinf=1.0, neginf=-1.0)\n",
        "    # return result\n",
        "\n",
        "\n",
        "batched_epsilon = jax.vmap(epsilon, in_axes=(None, 0, None, None))\n",
        "\n",
        "\n",
        "def shortcut_with_random_alpha(\n",
        "    v_theta: Callable[[chex.Array, float, float], chex.Array],\n",
        "    x: chex.Array,\n",
        "    t: chex.Array,\n",
        "    d: chex.Array,\n",
        "    alpha: chex.Array,  # New argument to specify the fraction of the interval for the first part\n",
        "):\n",
        "    # Compute the total step size (distance to move)\n",
        "    total_step_size = jnp.clip(t + 2 * d, 0.0, 1.0) - t  # No division by 2 needed now\n",
        "\n",
        "    # Compute the velocity for the first part of the split using alpha\n",
        "    first_part_velocity = v_theta(x, t, alpha * total_step_size)\n",
        "    first_part_state = x + first_part_velocity * alpha * total_step_size\n",
        "\n",
        "    # Compute the velocity for the second part of the split using (1-alpha)\n",
        "    second_part_velocity = v_theta(\n",
        "        first_part_state, t + alpha * total_step_size, (1 - alpha) * total_step_size\n",
        "    )\n",
        "\n",
        "    # Calculate the target shortcut step as the weighted sum of both parts\n",
        "    target_shortcut_step = jax.lax.stop_gradient(\n",
        "        alpha * first_part_velocity + (1 - alpha) * second_part_velocity\n",
        "    )\n",
        "\n",
        "    # Compute the error for shortcut consistency\n",
        "    error = (v_theta(x, t, total_step_size) - target_shortcut_step) ** 2\n",
        "\n",
        "    return error\n",
        "\n",
        "batched_shortcut_with_random_alpha = jax.vmap(\n",
        "    shortcut_with_random_alpha, in_axes=(None, 0, 0, 0, 0)\n",
        ")\n"
      ]
    },
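    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a reading of `epsilon` above, and assuming that `log_Z_t` holds an estimate of $\\partial_t \\log Z_t$ (which is how it enters the code), the residual evaluated per particle is\n",
        "\n",
        "$$\n",
        "\\varepsilon(x, t) = \\nabla \\cdot v_\\theta(x, t) + v_\\theta(x, t) \\cdot \\nabla_x \\log \\tilde{p}_t(x) + \\partial_t \\log \\tilde{p}_t(x) - \\partial_t \\log Z_t .\n",
        "$$\n",
        "\n",
        "Here $\\tilde{p}_t$ denotes the unnormalised density (notation introduced here). The residual vanishes when $v_\\theta$ satisfies the continuity equation $\\partial_t p_t + \\nabla \\cdot (p_t v_\\theta) = 0$ for the normalised density $p_t = \\tilde{p}_t / Z_t$.\n",
        "\n",
        "Similarly, `shortcut_with_random_alpha` can be read as the following consistency error, with $\\Delta = \\mathrm{clip}(t + 2d, 0, 1) - t$, the intermediate state $x' = x + \\alpha \\Delta \\, v_\\theta(x, t, \\alpha \\Delta)$, and $\\mathrm{sg}[\\cdot]$ denoting the stop-gradient:\n",
        "\n",
        "$$\n",
        "\\left( v_\\theta(x, t, \\Delta) - \\mathrm{sg}\\!\\left[ \\alpha \\, v_\\theta(x, t, \\alpha \\Delta) + (1 - \\alpha) \\, v_\\theta\\big(x', t + \\alpha \\Delta, (1 - \\alpha) \\Delta\\big) \\right] \\right)^2 ,\n",
        "$$\n",
        "\n",
        "where the square is taken element-wise."
      ]
    },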
    {
      "cell_type": "code",
      "execution_count": 122,
      "metadata": {
        "id": "ZJin2jyBE-qr"
      },
      "outputs": [],
      "source": [
        "def loss_fn(\n",
        "    v_theta: Callable[[chex.Array, float], chex.Array],\n",
        "    particles: Particle,\n",
        "    time_derivative_log_density: Callable[[chex.Array, float], float],\n",
        "    score_fn: Callable[[chex.Array, float], chex.Array],\n",
        "    alpha: Optional[\n",
        "        chex.Array\n",
        "    ] = None,  # Added: Pre-generated alpha for random shortcut\n",
        "    combined_loss: bool = False,\n",
        "    shortcut_weight: float = 0.5,\n",
        "    skip_shortcut: Optional[bool] = False,\n",
        ") -> Tuple[float, Float[Array, \" batch\"]]:  # Return loss value and raw epsilons\n",
        "    \"\"\"\n",
        "    Computes the loss for training the velocity field and returns the raw epsilons.\n",
        "\n",
        "    Args:\n",
        "        v_theta: The velocity field function taking (x, t) and returning velocity vector\n",
        "        particles: Batch of particles, shape (batch_size, num_samples, dim)\n",
        "        time_derivative_log_density: Function computing time derivative of log density\n",
        "        score_fn: Score function taking (x, t) and returning gradient of log density\n",
        "        shift_fn: Function taking (x) and returning shifted x\n",
        "        alpha: Optional array of shape (batch,) for random shortcut interpolation.\n",
        "\n",
        "    Returns:\n",
        "        Tuple containing:\n",
        "            - float: The computed loss value (scalar).\n",
        "            - Array: The raw epsilon values for each particle instance (batch,).\n",
        "    \"\"\"\n",
        "\n",
        "\n",
        "\n",
        "    raw_epsilons = batched_epsilon(\n",
        "        v_theta,\n",
        "        particles,\n",
        "        score_fn,\n",
        "        time_derivative_log_density,\n",
        "    )\n",
        "\n",
        "    # --- Calculate Loss ---\n",
        "    # Apply loss weight if necessary for loss calculation\n",
        "    epsilons_for_loss = raw_epsilons\n",
        "    if particles.loss_weight is not None:\n",
        "        epsilons_for_loss = raw_epsilons * particles.loss_weight\n",
        "\n",
        "    if combined_loss:\n",
        "        # Compute L1 and L2 loss for the (potentially weighted) epsilons\n",
        "        l1_loss = jnp.mean(jnp.abs(epsilons_for_loss))  # L1 (MAE)\n",
        "        l2_loss = jnp.mean(epsilons_for_loss**2)  # L2 (MSE)\n",
        "        _loss = 0.5 * l1_loss + 0.5 * l2_loss  # Adjust weights as needed\n",
        "    else:\n",
        "        _loss = jnp.mean(epsilons_for_loss**2)\n",
        "\n",
        "    # Add shortcut loss if applicable\n",
        "    if particles.d is not None and alpha is not None and not skip_shortcut:\n",
        "        # Use random alpha shortcut if alpha values are provided\n",
        "        short_cut_loss = batched_shortcut_with_random_alpha(\n",
        "            v_theta, particles.x, particles.t, particles.d, alpha\n",
        "        )\n",
        "        final_loss = _loss + shortcut_weight * jnp.mean(short_cut_loss)\n",
        "    else:\n",
        "        final_loss = _loss\n",
        "\n",
        "    # Return the final loss value AND the raw (unweighted) epsilons\n",
        "    return final_loss, raw_epsilons"
      ]
    },
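    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a reading of `loss_fn` above (the symbols $B$, $w_i$, $\\lambda$ and $s_i$ are notation introduced here for the batch size, `loss_weight`, `shortcut_weight` and the per-particle shortcut error), the default objective is\n",
        "\n",
        "$$\n",
        "\\mathcal{L} = \\frac{1}{B} \\sum_{i=1}^{B} (w_i \\, \\varepsilon_i)^2 + \\lambda \\, \\operatorname{mean}_i(s_i),\n",
        "$$\n",
        "\n",
        "where $w_i = 1$ if no loss weight is set, and the second term is only added when the particles carry a step size `d`, `alpha` is provided, and `skip_shortcut` is false. With `combined_loss=True` the first term is replaced by\n",
        "\n",
        "$$\n",
        "\\frac{1}{2} \\cdot \\frac{1}{B} \\sum_{i=1}^{B} \\left| w_i \\, \\varepsilon_i \\right| + \\frac{1}{2} \\cdot \\frac{1}{B} \\sum_{i=1}^{B} (w_i \\, \\varepsilon_i)^2 .\n",
        "$$\n",
        "\n",
        "Note that the weight multiplies $\\varepsilon_i$ before squaring, so it enters the L2 term as $w_i^2$."
      ]
    },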
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5nXiC0f-Hq32"
      },
      "source": [
        "## Define Configuration Class"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 123,
      "metadata": {
        "id": "H4f05M9YF5Qy"
      },
      "outputs": [],
      "source": [
        "from dataclasses import dataclass, field\n",
        "from typing import Any, Callable, List, Literal, Optional\n",
        "\n",
        "import jmp\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class SamplingConfig:\n",
        "    num_particles: int = 512  # N: Number of particles to simulate\n",
        "    num_timesteps: int = 32  # T: Number of timesteps for integration\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class TrainingConfig:\n",
        "    num_epochs: int = 200\n",
        "    steps_per_epoch: int = 100\n",
        "    learning_rate: float = 1e-3\n",
        "    gradient_clip_norm: Optional[float] = None\n",
        "    gradient_clip: Optional[float] = None\n",
        "    eval_frequency: int = 20\n",
        "    eval_loss_curve: bool = False\n",
        "    optimizer: Literal[\n",
        "        \"adam\",\n",
        "        \"adamw\",\n",
        "        \"sgd\",\n",
        "        \"rmsprop\",\n",
        "        \"adafactor\",\n",
        "        \"adagrad\",\n",
        "        \"adadelta\",\n",
        "        \"lamb\",\n",
        "        \"lion\",\n",
        "        \"adamax\",\n",
        "        \"fromage\",\n",
        "        \"noisy_sgd\",\n",
        "        \"lbfgs\",\n",
        "    ] = \"adamw\"\n",
        "    use_decoupled_loss: bool = False  # Whether to use decoupled loss function\n",
        "    training_data: Literal[\"combined\", \"vsmc\", \"random\"] = \"combined\"\n",
        "    # Optimizer parameters\n",
        "    weight_decay: float = 0.0\n",
        "    beta1: float = 0.9  # b1 for Adam-like optimizers\n",
        "    beta2: float = 0.999  # b2 for Adam-like optimizers\n",
        "    epsilon: float = 1e-8  # eps for numerical stability\n",
        "    momentum: float = 0.9  # momentum for SGD\n",
        "    nesterov: bool = False  # whether to use Nesterov momentum\n",
        "    noise_scale: float = 0.01  # eta for noisy SGD\n",
        "    # LBFGS specific parameters\n",
        "    lbfgs_memory_size: int = 10  # Memory size for LBFGS history\n",
        "    # Parameters for the LBFGS-specific training procedure in lbfgs_core.py\n",
        "    lbfgs_batch_size: int = 256  # Size of the fixed batch used in the inner LBFGS loop\n",
        "    lbfgs_max_inner_iterations: int = (\n",
        "        1000  # Max iterations for the inner LBFGS loop on a fixed batch\n",
        "    )\n",
        "    lbfgs_convergence_window: int = (\n",
        "        5  # Window size (N) for checking loss convergence in inner loop\n",
        "    )\n",
        "    lbfgs_convergence_threshold: float = (\n",
        "        1e-2  # Loss standard deviation threshold for inner loop convergence\n",
        "    )\n",
        "    # --- End LBFGS specific procedure parameters ---\n",
        "    time_batch_size: int = (\n",
        "        32  # Number of time points to use in each batch (standard training)\n",
        "    )\n",
        "    shortcut_size: List[int] = field(default_factory=lambda: [16, 32, 64, 128])\n",
        "    use_shortcut: bool = False\n",
        "    random_alpha: bool = False\n",
        "    estimator: str = \"hutchinson\"\n",
        "    n_probes: int = 5\n",
        "    r: int = 4\n",
        "    every_k_schedule: int = 1\n",
        "    use_schedule: bool = False\n",
        "    use_combined_loss: bool = False\n",
        "    shortcut_weight: float = 0.5\n",
        "    # Learning rate schedule parameters\n",
        "    schedule_init_value: float = 1e-5  # Initial learning rate value for schedule\n",
        "    schedule_warmup_steps: int = 10000  # Number of warmup steps\n",
        "    schedule_decay_epoch: int = 10  # Epoch at which to start decaying the learning rate\n",
        "    schedule_end_value: float = 1e-7  # Final learning rate value after decay\n",
        "    # Log Z estimation frequency (only used when use_decoupled_loss is True)\n",
        "    log_z_estimation_frequency: int = 1  # How often to estimate log_Z_t (in epochs)\n",
        "    reweight: bool = False  # Whether to use reweighting for log Z estimation\n",
        "    perturb: bool = False\n",
        "    perturbation_scale: float = 1.0\n",
        "    augment: bool = False  # Whether to use augmentations\n",
        "    translation_scale: float = 2.0  # Scale for translation augmentation\n",
        "    skip_shortcut: bool = False  # Whether to skip the shortcut connection\n",
        "    batch_size: int = 256  # B: Batch size for training\n",
        "    # RAD Sampling Config\n",
        "    use_rad_sampling: bool = False\n",
        "    rad_steps: int = 1  # Number of RAD steps *after* the initial step\n",
        "    rad_beta: float = 1.0  # Exponent for RAD probability calculation\n",
        "    rad_use_weights: bool = True  # Whether to use importance weights from RAD sampler\n",
        "    estimate_logz: bool = False  # Whether to use ZNet to estimate log Z\n",
        "    # Fourier Feature\n",
        "    use_nets_estimator: bool = False\n",
        "    learnable_path: bool = False\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class MCMCConfig:\n",
        "    method: Literal[\"hmc\", \"smc\", \"esmc\", \"vsmc\"] = \"hmc\"\n",
        "    num_steps: int = 5\n",
        "    num_integration_steps: int = 3\n",
        "    step_size: float = 0.01  # eta: MCMC step size\n",
        "    with_rejection: bool = False\n",
        "    use_control_variate: bool = False\n",
        "    lambda_max: float = 0.1  # Maximum value for lambda factor\n",
        "    lambda_epochs: float = 2000.0  # Number of epochs over which lambda increases\n",
        "    ess_threshold: float = 0.5\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class IntegrationConfig:\n",
        "    method: Literal[\"Euler\", \"Tsit5\"] = \"Euler\"\n",
        "    schedule: Literal[\"linear\", \"inverse_power\", \"power\", \"focus\"] = (\n",
        "        \"linear\"  # Added \"focus\"\n",
        "    )\n",
        "    continuous_time: bool = False\n",
        "    dt_clip: Optional[float] = None\n",
        "    gamma: Optional[float] = 0.25  # Gamma for inverse power and power schedules\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class ModelConfig:\n",
        "    hidden_dim: int = 256\n",
        "    num_layers: int = 3\n",
        "    num_heads: int = 4\n",
        "    mlp_depth: int = 2\n",
        "    norm: str = \"rms\"\n",
        "    theta: float = 10000.0\n",
        "    geonorm: bool = False\n",
        "    dropout: float = None\n",
        "    architecture: Literal[\"mlp\", \"pdn\", \"transformer\", \"emlp\", \"egnn\"] = \"mlp\"\n",
        "    embedding_dim: int = 64\n",
        "    embedder_width: int = 128\n",
        "    embedder_depth: int = 2\n",
        "    use_fourier_features: bool = False  # Whether to use Fourier features\n",
        "    num_bands: int = 16  # Number of bands for Fourier features\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class DensityConfig:\n",
        "    target_type: Literal[\n",
        "        \"gmm\",\n",
        "        \"mw32\",\n",
        "        \"dw4\",\n",
        "        \"lj13\",\n",
        "        \"sclj13\",\n",
        "        \"dw4o\",\n",
        "        \"tlj13\",\n",
        "        \"lj13b\",\n",
        "        \"lj13bt\",\n",
        "        \"lj13c\",\n",
        "    ] = \"gmm\"\n",
        "    initial_sigma: float = 20.0\n",
        "    score_norm: Optional[float] = None\n",
        "    annealing_path: Literal[\"linear\", \"geometric\"] = \"linear\"\n",
        "    shift_fn: Callable[[Any], Any] = field(default_factory=lambda: lambda x: x)\n",
        "    input_dim: Optional[int] = None\n",
        "    # LJ specific parameters\n",
        "    n_particles: Optional[int] = None\n",
        "    n_spatial_dim: Optional[int] = None\n",
        "    alpha: Optional[float] = None\n",
        "    epsilon_val: Optional[float] = None\n",
        "    min_dr: Optional[float] = 1e-3\n",
        "    m: Optional[int] = 1\n",
        "    n: Optional[int] = 1.0\n",
        "    c: Optional[float] = 1.0\n",
        "    r_min: Optional[float] = 0.2\n",
        "    log_prob_clip: Optional[float] = None\n",
        "    log_prob_clip_min: Optional[float] = None\n",
        "    log_prob_clip_max: Optional[float] = None\n",
        "    soft_clip: bool = False\n",
        "    include_harmonic: bool = True\n",
        "    cubic_spline: bool = False\n",
        "    # Data paths for some targets\n",
        "    data_path_test: Optional[str] = None\n",
        "    data_path_val: Optional[str] = None\n",
        "    data_path_train: Optional[str] = None\n",
        "    n_samples_eval: Optional[int] = 1024\n",
        "    prior_regularization: bool = False\n",
        "\n",
        "\n",
        "@dataclass\n",
        "class TrainingExperimentConfig:\n",
        "    sampling: SamplingConfig = field(default_factory=SamplingConfig)\n",
        "    training: TrainingConfig = field(default_factory=TrainingConfig)\n",
        "    mcmc: MCMCConfig = field(default_factory=MCMCConfig)\n",
        "    integration: IntegrationConfig = field(default_factory=IntegrationConfig)\n",
        "    model: ModelConfig = field(default_factory=ModelConfig)\n",
        "    density: DensityConfig = field(default_factory=DensityConfig)\n",
        "    offline: bool = False\n",
        "    debug: bool = False\n",
        "    mixed_precision: bool = False\n",
        "    resume_from: Optional[str] = None\n",
        "    mp_policy: jmp.Policy = None  # Field to store the JMP mixed precision policy\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hlu0WfH2IPUn"
      },
      "source": [
        "## Define MCMC related methods"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zEFGLeM4JTtp"
      },
      "source": [
        "### HMC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 124,
      "metadata": {
        "id": "tFazZFrfJUsX"
      },
      "outputs": [],
      "source": [
        "import blackjax\n",
        "\n",
        "@eqx.filter_jit\n",
        "def sample_hamiltonian_monte_carlo_blackjax(\n",
        "    key: PRNGKeyArray,\n",
        "    time_dependent_log_density: Callable[[Float[Array, \"dim\"], float], float],\n",
        "    x: Float[Array, \"dim\"],\n",
        "    t: float,\n",
        "    step_size: Float[Array, \"\"], # Now explicitly passed, can be 0D array or float\n",
        "    inverse_mass_matrix: Optional[Float[Array, \"dim dim\"]],\n",
        "    num_integration_steps: int, # Now explicitly passed\n",
        "    num_hmc_steps: int = 1, # Renamed to avoid confusion, typically 1 inside SMC step\n",
        "    **kwargs, # Keep for potential future use / compatibility if needed elsewhere\n",
        ") -> Float[Array, \"dim\"]:\n",
        "    \"\"\"\n",
        "    Hamiltonian Monte Carlo using BlackJAX.\n",
        "\n",
        "    Args:\n",
        "        key: Random key\n",
        "        time_dependent_log_density: Log density function\n",
        "        x: Initial position\n",
        "        t: Time parameter\n",
        "        num_steps: Number of HMC steps\n",
        "        integration_steps: Number of integration steps per HMC step\n",
        "        step_size: Step size\n",
        "        inverse_mass_matrix: Optional covariance matrix/diagonal\n",
        "        **kwargs: Additional arguments (for compatibility)\n",
        "\n",
        "    Returns:\n",
        "        Final position after HMC\n",
        "    \"\"\"\n",
        "    # Use provided inv mass matrix, default to identity if None\n",
        "    dim = x.shape[-1]\n",
        "    _inverse_mass_matrix = jnp.eye(dim) if inverse_mass_matrix is None else inverse_mass_matrix\n",
        "\n",
        "    # Initialize Blackjax HMC kernel with passed parameters\n",
        "    hmc = blackjax.hmc(\n",
        "        logdensity_fn=lambda state: time_dependent_log_density(state, t),\n",
        "        step_size=step_size,\n",
        "        inverse_mass_matrix=_inverse_mass_matrix,\n",
        "        num_integration_steps=num_integration_steps,\n",
        "    )\n",
        "    hmc_kernel = jax.jit(hmc.step)\n",
        "    initial_state = hmc.init(x)\n",
        "\n",
        "    # Define the loop body for sequential HMC steps\n",
        "    @jax.jit\n",
        "    def one_step(state, rng_key):\n",
        "        state, _ = hmc_kernel(rng_key, state)\n",
        "        return state, state # Carry the state, output the state\n",
        "\n",
        "    # Perform HMC steps using lax.scan\n",
        "    keys = jax.random.split(key, num_hmc_steps) # Use num_hmc_steps\n",
        "    final_state, _ = jax.lax.scan(one_step, initial_state, keys)\n",
        "\n",
        "    return final_state.position # Return only the final position\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A0lav3qQI6zm"
      },
      "source": [
        "### SMC"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 125,
      "metadata": {
        "id": "eDJis5duI5Tc"
      },
      "outputs": [],
      "source": [
        "### SMC\n",
        "from blackjax.smc.resampling import systematic\n",
        "from jaxtyping import Int\n",
        "@jax.jit\n",
        "def ess_from_logweights(log_w: jax.Array) -> jax.Array:\n",
        "    \"\"\"Effective sample size given *unnormalised* log-weights.\"\"\"\n",
        "    log_w = log_w - jax.scipy.special.logsumexp(log_w)   # normalise in log-space\n",
        "    w      = jnp.exp(log_w)                              # plain weights\n",
        "    return 1.0 / jnp.sum(w * w)\n",
        "\n",
        "@jax.jit\n",
        "def log_weights_to_weights(log_weights: Float[Array, \"num_samples\"]) -> Float[Array, \"num_samples\"]:\n",
        "    \"\"\"\n",
        "    Convert log weights to weights.\n",
        "\n",
        "    Args:\n",
        "        log_weights: Log weights. Shape: (num_samples,).\n",
        "\n",
        "    Returns:\n",
        "        Weights: Weights. Shape: (num_samples,).\n",
        "    \"\"\"\n",
        "    log_sum_w = jax.scipy.special.logsumexp(log_weights)\n",
        "    log_normalized_weights = log_weights - log_sum_w\n",
        "    weights = jnp.exp(log_normalized_weights)\n",
        "\n",
        "    return weights\n",
        "\n",
        "@eqx.filter_jit\n",
        "# @chex.assert_max_traces(n=5)\n",
        "def generate_samples_with_smc(\n",
        "    key: PRNGKeyArray,\n",
        "    initial_samples: Float[Array, \"num_samples dim\"],\n",
        "    time_dependent_log_density: Callable[[Float[Array, \"dim\"], float], float],\n",
        "    ts: Float[Array, \"num_timesteps\"],\n",
        "    num_mcmc_steps: int = 10,\n",
        "    integration_steps: int = 3,\n",
        "    eta: float = 0.1,\n",
        "    ess_threshold: float = 0.6,\n",
        "    resampling_fn: Callable[\n",
        "        [PRNGKeyArray, Float[Array, \"num_samples\"], int], Int[Array, \"num_samples\"]\n",
        "    ] = systematic,\n",
        "    v_theta: Optional[Callable[[Float[Array, \"dim\"], float], Float[Array, \"dim\"]]] = None,\n",
        "    use_shortcut: bool = False,\n",
        "    initial_log_weights: Optional[Float[Array, \"num_samples\"]] = None,\n",
        "    lambda_factor: Float[Array, \"\"] = 1.0,\n",
        "    hmc_parameters: Optional[Dict] = None, # Add optional HMC parameters Pytree\n",
        "    use_v_theta_only: bool = False, # If True, skip HMC and only use v_theta drift\n",
        ") -> Dict[str, Union[Float[Array, \"num_timesteps num_samples dim\"],\n",
        "                       Float[Array, \"num_timesteps num_samples\"],\n",
        "                       Float[Array, \"num_timesteps\"]]]:\n",
        "\n",
        "    num_samples = initial_samples.shape[0]\n",
        "    # Initialize particles with provided samples or generate new ones\n",
        "    chex.assert_rank(initial_samples, 2)\n",
        "    if initial_log_weights is None:\n",
        "        initial_log_weights = jnp.full((num_samples,), -jnp.log(num_samples), dtype=jnp.float32)\n",
        "\n",
        "    chex.assert_rank(initial_log_weights, 1)\n",
        "\n",
        "    # Split keys for all samples and all timesteps, but exclude the first timestep (since we use initial_samples)\n",
        "    num_scan_steps = ts.shape[0] - 1  # We need one less step than the number of timestamps\n",
        "    sample_keys = jax.random.split(key, num_samples * num_scan_steps).reshape(\n",
        "        num_scan_steps, num_samples, -1\n",
        "    )\n",
        "\n",
        "    # Initial particles at time ts[0]\n",
        "    particles = {\n",
        "        \"positions\": initial_samples,\n",
        "        \"log_weights\": initial_log_weights,\n",
        "        \"ess\": jnp.array(1.0),\n",
        "    }\n",
        "\n",
        "    def _delta(positions, t, t_prev):\n",
        "        return time_dependent_log_density(\n",
        "            positions, t\n",
        "        ) - time_dependent_log_density(positions, t_prev)\n",
        "\n",
        "    batched_delta = jax.vmap(_delta, in_axes=(0, None, None))\n",
        "    if v_theta is not None:\n",
        "        if use_shortcut:\n",
        "            batched_v_theta = jax.vmap(v_theta, in_axes=(0, None, None))\n",
        "        else:\n",
        "            batched_v_theta = jax.vmap(v_theta, in_axes=(0, None))\n",
        "\n",
        "    # Hoist HMC vmap outside scan and JIT compile, capturing static num_mcmc_steps\n",
        "    batched_hmc_step = jax.jit(jax.vmap(\n",
        "        lambda key, pos, t, step_size, inv_mass_matrix, num_integration_steps: sample_hamiltonian_monte_carlo_blackjax(\n",
        "            key=key,\n",
        "            time_dependent_log_density=time_dependent_log_density,\n",
        "            x=pos,\n",
        "            t=t,\n",
        "            step_size=step_size,\n",
        "            inverse_mass_matrix=inv_mass_matrix,\n",
        "            num_integration_steps=num_integration_steps,\n",
        "            num_hmc_steps=num_mcmc_steps,  # static capture\n",
        "        ),\n",
        "        in_axes=(0, 0, None, None, None, None)\n",
        "    ))\n",
        "\n",
        "    @jax.jit\n",
        "    def _resample(key: PRNGKeyArray,\n",
        "                  positions: Float[Array, \"num_samples dim\"],\n",
        "                  log_weights: Float[Array, \"num_samples\"]\n",
        "                 ) -> Tuple[Float[Array, \"num_samples dim\"], Float[Array, \"num_samples\"]]:\n",
        "        \"\"\"\n",
        "        Resample particles based on their log weights.\n",
        "\n",
        "        Args:\n",
        "            key: JAX PRNG key.\n",
        "            positions: Current particle positions. Shape: (num_samples, ...).\n",
        "            log_weights: Current log weights. Shape: (num_samples,).\n",
        "\n",
        "        Returns:\n",
        "            new_positions: Resampled particle positions. Shape: (num_samples, ...).\n",
        "            new_log_weights: Reset log weights. Shape: (num_samples,).\n",
        "        \"\"\"\n",
        "        # Normalize log_weights to prevent numerical underflow/overflow\n",
        "        weights = log_weights_to_weights(log_weights)\n",
        "        # Perform resampling to obtain indices\n",
        "        indices = resampling_fn(key, weights, num_samples)  # Shape: (num_samples,)\n",
        "\n",
        "        # Resample positions\n",
        "        new_positions = jnp.take(positions, indices, axis=0)\n",
        "\n",
        "        # Reset log_weights to uniform\n",
        "        new_log_weights = jnp.full((num_samples,), -jnp.log(num_samples), dtype=jnp.float32)\n",
        "\n",
        "        return new_positions, new_log_weights\n",
        "\n",
        "    def step(carry, inputs):\n",
        "        particles_prev, t_idx = carry\n",
        "        # Unpack inputs: keys, covariance for this step, hmc params for this step\n",
        "        keys, hmc_params_t = inputs\n",
        "\n",
        "        # Get current and next time from ts\n",
        "        t_prev = ts[t_idx]\n",
        "        t = ts[t_idx + 1]\n",
        "\n",
        "        # Time step for ODE integration\n",
        "        d = t - t_prev\n",
        "\n",
        "        # Compute ESS and Resample if necessary\n",
        "        # ess_val = ess(log_weights=particles_prev[\"log_weights\"])  # Scalar\n",
        "        ess_val = ess_from_logweights(particles_prev[\"log_weights\"])  # Scalar\n",
        "        ess_percentage = ess_val / num_samples  # Scalar\n",
        "\n",
        "        # Define the condition for resampling\n",
        "        def do_resample():\n",
        "            resample_key, _ = jax.random.split(keys[0])\n",
        "            # Resample particles\n",
        "            new_positions, new_log_weights = _resample(\n",
        "                resample_key, particles_prev[\"positions\"], particles_prev[\"log_weights\"]\n",
        "            )\n",
        "\n",
        "            return {\"positions\": new_positions, \"log_weights\": new_log_weights}\n",
        "\n",
        "        def do_nothing():\n",
        "            # Keep the particles as is with normalized log weights\n",
        "            log_weights_normalized = particles_prev[\"log_weights\"] - jax.scipy.special.logsumexp(\n",
        "                particles_prev[\"log_weights\"]\n",
        "            )\n",
        "\n",
        "            return {\n",
        "                \"positions\": particles_prev[\"positions\"],\n",
        "                \"log_weights\": log_weights_normalized,\n",
        "            }\n",
        "\n",
        "        # Conditionally resample based on ESS percentage\n",
        "        particles_new = jax.lax.cond(\n",
        "            ess_percentage < ess_threshold,\n",
        "            do_resample,\n",
        "            do_nothing,\n",
        "        )\n",
        "        particles_new[\"ess\"] = ess_percentage\n",
        "\n",
        "        # Apply shift function\n",
        "        shifted_positions = particles_new[\"positions\"]\n",
        "        # If v_theta is provided, use it to propagate particles first\n",
        "        if v_theta is not None:\n",
        "            if use_shortcut:\n",
        "                # Match Euler's behavior by using t_prev and absolute dt\n",
        "                propagated_positions = shifted_positions + lambda_factor * d * batched_v_theta(\n",
        "                    shifted_positions, t_prev, d\n",
        "                )\n",
        "            else:\n",
        "                propagated_positions = shifted_positions + lambda_factor * d * batched_v_theta(\n",
        "                    shifted_positions, t_prev\n",
        "                )\n",
        "        else:\n",
        "            propagated_positions = shifted_positions\n",
        "\n",
        "        # --- Determine HMC parameters for this step t ---\n",
        "        if hmc_params_t is not None:\n",
        "            # Use pre-computed parameters for this time step\n",
        "            hmc_step_size = hmc_params_t[\"step_size\"]\n",
        "            hmc_inv_mass_matrix = hmc_params_t[\"inverse_mass_matrix\"]\n",
        "            hmc_num_integration_steps = hmc_params_t[\"num_integration_steps\"]\n",
        "        else:\n",
        "            # Use fallback parameters (eta, integration_steps, cov)\n",
        "            hmc_step_size = eta\n",
        "            hmc_inv_mass_matrix = None # This was determined earlier (lines 190-200)\n",
        "            hmc_num_integration_steps = integration_steps # From outer scope\n",
        "\n",
        "        # --- Define and apply HMC kernel for this step ---\n",
        "        # Apply JIT-compiled batched HMC kernel if not skipping\n",
        "        if not use_v_theta_only:\n",
        "            propagated_positions = batched_hmc_step(\n",
        "                keys,\n",
        "                propagated_positions,\n",
        "                t,\n",
        "                hmc_step_size,\n",
        "                hmc_inv_mass_matrix,\n",
        "                hmc_num_integration_steps,\n",
        "            )  # Shape: (num_samples, ...)\n",
        "        # else: propagated_positions remains the result from the v_theta step\n",
        "\n",
        "        # Compute incremental weights\n",
        "        w_delta = batched_delta(propagated_positions, t, t_prev)\n",
        "        # Update log weights in log space\n",
        "        next_log_weights = particles_new[\"log_weights\"] + w_delta\n",
        "        next_log_weights = next_log_weights - jax.scipy.special.logsumexp(\n",
        "            next_log_weights\n",
        "        )\n",
        "\n",
        "        # Update carry with new particles and next time index\n",
        "        # Include the ess key to match the input carry structure\n",
        "        new_carry = (\n",
        "            {\n",
        "                \"positions\": propagated_positions,\n",
        "                \"log_weights\": next_log_weights,\n",
        "                \"ess\": particles_new[\"ess\"],  # Make sure to include ess in the new carry\n",
        "            },\n",
        "            t_idx + 1,\n",
        "        )\n",
        "\n",
        "        # Output particles at time t\n",
        "        return new_carry, {\n",
        "            \"positions\": propagated_positions,\n",
        "            \"log_weights\": next_log_weights,\n",
        "            \"ess\": particles_new[\"ess\"],\n",
        "        }\n",
        "\n",
        "    # Prepare HMC parameters for scan (for t=1 to T)\n",
        "    if hmc_parameters is not None:\n",
        "        # Slice parameters Pytree for each step (t=1 to T)\n",
        "        # Assumes parameters have shape (num_timesteps, ...) matching ts\n",
        "        chex.assert_tree_shape_prefix(hmc_parameters, (ts.shape[0],)) # Sanity check\n",
        "        scan_hmc_params = jax.tree.map(lambda x: x[1:], hmc_parameters)\n",
        "    else:\n",
        "        # If no pre-computed params, create a placeholder input for scan.\n",
        "        # Pass None for each step; the step function handles this.\n",
        "        # We need *something* with the correct length for scan to iterate over.\n",
        "        # num_scan_steps = ts.shape[0] - 1\n",
        "        # scan_hmc_params = [None] * num_scan_steps # List of Nones\n",
        "        scan_hmc_params = None\n",
        "\n",
        "    # Run scan over time indices from 0 to num_timesteps-2\n",
        "    # This will generate particles at times ts[1] to ts[num_timesteps-1]\n",
        "    _, scan_particles = jax.lax.scan(\n",
        "        step,\n",
        "        (particles, 0),  # Initial carry: particles at ts[0] and time index 0\n",
        "        (sample_keys, scan_hmc_params),  # Inputs for each time step\n",
        "    )\n",
        "\n",
        "    # Now we need to include the initial particles at ts[0]\n",
        "    all_positions = jnp.concatenate([\n",
        "        jnp.expand_dims(particles[\"positions\"], axis=0),\n",
        "        scan_particles[\"positions\"]\n",
        "    ], axis=0)\n",
        "\n",
        "    all_log_weights = jnp.concatenate([\n",
        "        jnp.expand_dims(particles[\"log_weights\"], axis=0),\n",
        "        scan_particles[\"log_weights\"]\n",
        "    ], axis=0)\n",
        "\n",
        "    all_ess = jnp.concatenate([\n",
        "        jnp.expand_dims(particles[\"ess\"], axis=0),\n",
        "        scan_particles[\"ess\"]\n",
        "    ], axis=0)\n",
        "\n",
        "    # Convert all log weights to weights\n",
        "    all_weights = jax.vmap(log_weights_to_weights)(all_log_weights)\n",
        "\n",
        "    return {\n",
        "        \"positions\": all_positions,\n",
        "        \"weights\": all_weights,\n",
        "        \"ess\": all_ess,\n",
        "    }\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 126,
      "metadata": {
        "id": "p7Sq4G-LIZ36"
      },
      "outputs": [],
      "source": [
        "# @eqx.filter_jit\n",
        "def sample_with_mcmc(\n",
        "    key: PRNGKeyArray,\n",
        "    initial_samples: Float[Array, \"num_samples dim\"],\n",
        "    v_theta: Optional[Callable[[Float[Array, \"dim\"], float], Float[Array, \"dim\"]]] = None,\n",
        "    time_dependent_log_density: Optional[Callable[[Float[Array, \"dim\"], float], float]] = None,\n",
        "    incremental_log_delta: Optional[Callable[[Float[Array, \"dim\"], float], float]] = None,\n",
        "    mcmc_method: str = \"none\",\n",
        "    ts: Optional[Float[Array, \"num_timesteps\"]] = None,\n",
        "    initial_log_probs: Optional[Float[Array, \"num_samples\"]] = None,\n",
        "    shift_fn: Callable[[Float[Array, \"dim\"]], Float[Array, \"dim\"]] = lambda x: x,\n",
        "    use_shortcut: bool = False,\n",
        "    num_steps: int = 10,\n",
        "    integration_steps: int = 3,\n",
        "    eta: float = 0.1,\n",
        "    ess_threshold: float = 0.6,\n",
        "    estimate_covariance: bool = False,\n",
        "    covariance: Optional[Float[Array, \"dim dim\"]] = None,\n",
        "    solver: str = \"Euler\",\n",
        "    lambda_factor: Float[Array, \"\"] = 1.0,\n",
        "    estimate_dt_logZ: bool = False,\n",
        "    **kwargs\n",
        ") -> Dict[str, Union[Float[Array, \"num_timesteps num_samples dim\"],\n",
        "                    Float[Array, \"num_timesteps num_samples\"],\n",
        "                    Float[Array, \"num_timesteps\"]]]:\n",
        "    \"\"\"\n",
        "    Unified interface for generating samples with MCMC methods.\n",
        "\n",
        "    This function provides a standardized interface for all MCMC sampling\n",
        "    methods, including direct sampling (no MCMC), Hamiltonian Monte Carlo (HMC),\n",
        "    and Sequential Monte Carlo (SMC).\n",
        "\n",
        "    Args:\n",
        "        key: Random key\n",
        "        initial_samples: Initial samples with shape (num_samples, dim)\n",
        "        v_theta: Velocity field model (optional)\n",
        "        time_dependent_log_density: Log density function (t, x) -> log p(x, t)\n",
        "        mcmc_method: MCMC method to use (\"none\", \"hmc\", \"smc\", \"vsmc\")\n",
        "        ts: Time steps with shape (num_timesteps,)\n",
        "        initial_log_probs: Optional log probabilities for initial samples with shape (num_samples,)\n",
        "        shift_fn: Function to shift samples (for periodic boundaries, etc.)\n",
        "        use_shortcut: Whether to use shortcut mechanism for velocity field\n",
        "        num_steps: Number of MCMC steps\n",
        "        integration_steps: Number of integration steps per MCMC step\n",
        "        eta: Step size for MCMC\n",
        "        ess_threshold: Threshold for resampling in SMC\n",
        "        estimate_covariance: Whether to estimate covariance in SMC\n",
        "        covariance: Fixed covariance matrix (optional) with shape (dim, dim)\n",
        "        solver: Integration solver to use (\"Euler\" or \"Tsit5\")\n",
        "        lambda_factor: Factor for adjusting the importance of the velocity field in SMC\n",
        "        **kwargs: Additional arguments\n",
        "\n",
        "    Returns:\n",
        "        Dictionary containing:\n",
        "        - positions: Generated samples with shape (num_timesteps, num_samples, dim)\n",
        "        - weights: Sample weights with shape (num_timesteps, num_samples)\n",
        "        - (optional) ess: Effective sample size for SMC with shape (num_timesteps,)\n",
        "        - (optional) log_probs: Log probabilities with shape (num_timesteps, num_samples)\n",
        "    \"\"\"\n",
        "    # Assert initial_samples dimensions\n",
        "    chex.assert_rank(initial_samples, 2)\n",
        "\n",
        "    if initial_log_probs is None:\n",
        "        initial_log_probs = jnp.zeros(initial_samples.shape[0])\n",
        "\n",
        "    chex.assert_rank(initial_log_probs, 1)\n",
        "\n",
        "    # Apply MCMC based on the specified method\n",
        "    if mcmc_method == \"none\":\n",
        "        # No MCMC, just uniform weights and run the ode\n",
        "        output_samples, _ = solve_neural_ode_euler(\n",
        "            v_theta=v_theta if not estimate_dt_logZ else lambda *args: v_theta(*args)[0],\n",
        "            y0=initial_samples,\n",
        "            ts=ts,\n",
        "            use_shortcut=use_shortcut,\n",
        "            exact_logp=True,\n",
        "            forward=True,\n",
        "            save_trajectory=True,\n",
        "        )\n",
        "\n",
        "        weights = jnp.ones((ts.shape[0], initial_samples.shape[0])) / initial_samples.shape[0]\n",
        "        return {\n",
        "            \"positions\": output_samples,\n",
        "            \"weights\": weights,\n",
        "        }\n",
        "    elif mcmc_method in [\"smc\", \"vsmc\"]:\n",
        "        # Use SMC or VSMC\n",
        "        key, subkey = jax.random.split(key)\n",
        "\n",
        "        # For VSMC, we use the velocity field\n",
        "        v_theta_smc = v_theta if mcmc_method == \"vsmc\" else None\n",
        "\n",
        "        return generate_samples_with_smc(\n",
        "            key=subkey,\n",
        "            initial_samples=initial_samples,\n",
        "            time_dependent_log_density=time_dependent_log_density,\n",
        "            ts=ts,\n",
        "            num_mcmc_steps=num_steps,\n",
        "            integration_steps=integration_steps,\n",
        "            eta=eta,\n",
        "            ess_threshold=ess_threshold,\n",
        "            v_theta=v_theta_smc,\n",
        "            use_shortcut=use_shortcut,\n",
        "            lambda_factor=lambda_factor,\n",
        "        )\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown MCMC method: {mcmc_method}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 127,
      "metadata": {
        "id": "khL21dupHvPC"
      },
      "outputs": [],
      "source": [
        "def generate_samples_with_optional_mcmc(\n",
        "    key: jax.random.PRNGKey,\n",
        "    v_theta: Callable,\n",
        "    ts: Float[Array, \" time\"],\n",
        "    path_distribution: AnnealedDistribution,\n",
        "    config: TrainingExperimentConfig,\n",
        "    mcmc_method: Optional[str] = None,\n",
        "    force_finite: bool = False,\n",
        "    num_samples: Optional[int] = None,\n",
        "    lambda_factor: Float[Array, \"\"] = 1.0,\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"\n",
        "    Generate samples with or without MCMC correction.\n",
        "\n",
        "    Args:\n",
        "        key: Random key for this operation (will be split internally).\n",
        "        v_theta: Velocity field model\n",
        "        ts: Time steps\n",
        "        path_distribution: Annealed distribution\n",
        "        config: Training experiment configuration\n",
        "        mcmc_method: MCMC method to use (e.g., \"smc\", \"hmc\", \"none\"). Defaults to config.mcmc.method.\n",
        "        force_finite: Whether to replace non-finite values with finite ones\n",
        "        num_samples: Number of sample particles. Defaults to config.sampling.num_particles.\n",
        "        lambda_factor: Factor controlling the contribution of the velocity field\n",
        "\n",
        "    Returns:\n",
        "        Samples dictionary containing at least \"positions\" and \"weights\".\n",
        "    \"\"\"\n",
        "    # We generate samples in full precision\n",
        "    ts_compute = config.mp_policy.cast_to_output(ts)\n",
        "\n",
        "    # Determine MCMC method\n",
        "    mcmc_method = config.mcmc.method if mcmc_method is None else mcmc_method\n",
        "\n",
        "    _num_samples = config.sampling.num_particles if num_samples is None else num_samples\n",
        "\n",
        "    # Split key for independent operations\n",
        "    key_initial, key_mcmc = jax.random.split(key)\n",
        "\n",
        "    # Generate initial samples\n",
        "    initial_samples = path_distribution.sample_initial(\n",
        "        key_initial, (_num_samples,)\n",
        "    ).astype(config.mp_policy.output_dtype)\n",
        "\n",
        "    # Use unified MCMC interface\n",
        "    samples = sample_with_mcmc(\n",
        "        key=key_mcmc,  # Use the dedicated key\n",
        "        initial_samples=initial_samples,\n",
        "        v_theta=v_theta,\n",
        "        time_dependent_log_density=path_distribution.time_dependent_log_prob,\n",
        "        incremental_log_delta=path_distribution.incremental_log_delta,\n",
        "        mcmc_method=mcmc_method,\n",
        "        ts=ts_compute,\n",
        "        shift_fn=config.density.shift_fn,\n",
        "        use_shortcut=config.training.use_shortcut,\n",
        "        num_steps=config.mcmc.num_steps,\n",
        "        integration_steps=config.mcmc.num_integration_steps,\n",
        "        eta=config.mcmc.step_size,\n",
        "        ess_threshold=config.mcmc.ess_threshold,\n",
        "        estimate_covariance=False,  # Currently hardcoded\n",
        "        solver=config.integration.method,\n",
        "        lambda_factor=lambda_factor,\n",
        "        estimate_dt_logZ=config.training.estimate_logz,\n",
        "    )\n",
        "\n",
        "    if force_finite:\n",
        "        samples[\"positions\"] = jnp.nan_to_num(\n",
        "            samples[\"positions\"], nan=0.0, posinf=1.0, neginf=-1.0\n",
        "        )\n",
        "    chex.assert_type(samples[\"positions\"], config.mp_policy.output_dtype)\n",
        "    chex.assert_shape(\n",
        "        samples[\"positions\"], (None, _num_samples, None)\n",
        "    )  # (time, num_samples, dim)\n",
        "\n",
        "    return samples"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ifMH_H9bJmcu"
      },
      "source": [
        "## Training Boilerplate"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 128,
      "metadata": {
        "id": "5nQKWezXIRuf"
      },
      "outputs": [],
      "source": [
        "jitted_loss_fn = eqx.filter_jit(loss_fn)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 129,
      "metadata": {
        "id": "2kPBYk6TJqqT"
      },
      "outputs": [],
      "source": [
        "@eqx.filter_jit\n",
        "def rad_sampler_joint(\n",
        "    key: jax.random.PRNGKey,\n",
        "    particles: Particle,\n",
        "    epsilons: Float[Array, \"\"],\n",
        "    beta: float = 1.0,\n",
        "    path_dist: AnnealedDistribution = None,\n",
        "    use_weights: bool = False,\n",
        ") -> Tuple[Particle, Float[Array, \"\"]]:\n",
        "    \"\"\"\n",
        "    Residual‑based Adaptive Distribution over joint (t,x).\n",
        "    Returns (Particle, weight).\n",
        "    \"\"\"\n",
        "    key_cand, key_pick = jax.random.split(key)\n",
        "\n",
        "    r = (epsilons**2) ** beta + 1e-12\n",
        "    pi = r / jnp.sum(r)\n",
        "    pi = pi.reshape(-1)  # Ensure pi is 1D\n",
        "\n",
        "    # ---- 3. resample ----------------------------------------------------\n",
        "    idx = jax.random.choice(\n",
        "        key_pick,\n",
        "        particles.x.shape[0],\n",
        "        shape=(particles.x.shape[0],),\n",
        "        replace=False,\n",
        "        p=pi,\n",
        "    )\n",
        "\n",
        "    part = Particle(\n",
        "        x=particles.x[idx],\n",
        "        t=particles.t[idx],\n",
        "        log_Z_t=particles.log_Z_t[idx] if particles.log_Z_t is not None else None,\n",
        "        d=None if particles.d is None else particles.d[idx],\n",
        "    )\n",
        "\n",
        "    # ---- 4. optional IS weight -----------------------------------------\n",
        "    if use_weights:\n",
        "        # log p_t(x) is available from path_dist.logpdf\n",
        "        logP = jax.vmap(path_dist.time_dependent_log_prob, in_axes=(0, 0))(\n",
        "            part.x, part.t\n",
        "        )\n",
        "        w = jnp.exp(logP) / (pi[idx] * part.x.shape[0])\n",
        "    else:\n",
        "        w = jnp.ones((part.x.shape[0],), dtype=part.x.dtype)\n",
        "\n",
        "    return part, w"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 130,
      "metadata": {
        "id": "78hKu2RBJuo6"
      },
      "outputs": [],
      "source": [
        "import optax\n",
        "\n",
        "def soft_clip(x, min_val, max_val, alpha=10.0):\n",
        "    return min_val + (max_val - min_val) / (1 + jnp.exp(-alpha * (x - min_val)))\n",
        "\n",
        "\n",
        "def get_optimizer(\n",
        "    name: str,\n",
        "    learning_rate: float | optax.Schedule | None,\n",
        "    weight_decay: float = 0.0,\n",
        "    b1: float = 0.9,\n",
        "    b2: float = 0.999,\n",
        "    eps: float = 1e-8,\n",
        "    momentum: float = 0.9,\n",
        "    nesterov: bool = False,\n",
        "    **kwargs,\n",
        ") -> Union[optax.GradientTransformation, optax.GradientTransformationExtraArgs]:\n",
        "    \"\"\"Creates optimizer based on name and parameters.\n",
        "\n",
        "    Args:\n",
        "        name: Name of optimizer ('adam', 'adamw', 'sgd', 'rmsprop', 'adafactor', 'adagrad', 'adadelta', 'lamb', 'lion', 'adamax', 'fromage', 'noisy_sgd', 'lbfgs')\n",
        "        learning_rate: Learning rate for the optimizer\n",
        "        weight_decay: Weight decay coefficient (L2 regularization)\n",
        "        b1: First moment decay rate (for Adam-like optimizers)\n",
        "        b2: Second moment decay rate (for Adam-like optimizers)\n",
        "        eps: Small constant for numerical stability\n",
        "        momentum: Momentum coefficient for SGD\n",
        "        nesterov: Whether to use Nesterov momentum\n",
        "        **kwargs: Additional optimizer-specific parameters\n",
        "\n",
        "    Returns:\n",
        "        optax.GradientTransformation: The configured optimizer\n",
        "    \"\"\"\n",
        "    if weight_decay > 0 and name not in [\"adamw\", \"lamb\"]:\n",
        "        # Add weight decay as a separate transformation for optimizers that don't include it\n",
        "        base = None\n",
        "        if name == \"adam\":\n",
        "            base = optax.adam(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps)\n",
        "        elif name == \"sgd\":\n",
        "            base = optax.sgd(\n",
        "                learning_rate=learning_rate, momentum=momentum, nesterov=nesterov\n",
        "            )\n",
        "        elif name == \"rmsprop\":\n",
        "            base = optax.rmsprop(\n",
        "                learning_rate=learning_rate,\n",
        "                decay=b1,\n",
        "                eps=eps,\n",
        "                momentum=momentum,\n",
        "                nesterov=nesterov,\n",
        "            )\n",
        "        elif name == \"adafactor\":\n",
        "            base = optax.adafactor(learning_rate=learning_rate)\n",
        "        elif name == \"adagrad\":\n",
        "            base = optax.adagrad(learning_rate=learning_rate, eps=eps)\n",
        "        elif name == \"adadelta\":\n",
        "            base = optax.adadelta(learning_rate=learning_rate, eps=eps)\n",
        "        elif name == \"lion\":\n",
        "            base = optax.lion(learning_rate=learning_rate, b1=b1, b2=b2)\n",
        "        elif name == \"adamax\":\n",
        "            base = optax.adamax(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps)\n",
        "        elif name == \"fromage\":\n",
        "            base = optax.fromage(learning_rate=learning_rate)\n",
        "        elif name == \"noisy_sgd\":\n",
        "            base = optax.noisy_sgd(\n",
        "                learning_rate=learning_rate, eta=kwargs.get(\"noise_scale\", 0.01)\n",
        "            )\n",
        "\n",
        "        if base is not None:\n",
        "            return optax.chain(optax.add_decayed_weights(weight_decay), base)\n",
        "\n",
        "    # Optimizers with built-in weight decay or no weight decay needed\n",
        "    if name == \"adamw\":\n",
        "        return optax.adamw(\n",
        "            learning_rate=learning_rate,\n",
        "            b1=b1,\n",
        "            b2=b2,\n",
        "            eps=eps,\n",
        "            weight_decay=weight_decay,\n",
        "            nesterov=nesterov,\n",
        "        )\n",
        "    elif name == \"lamb\":\n",
        "        return optax.lamb(\n",
        "            learning_rate=learning_rate,\n",
        "            b1=b1,\n",
        "            b2=b2,\n",
        "            eps=eps,\n",
        "            weight_decay=weight_decay,\n",
        "        )\n",
        "    elif name == \"adam\":\n",
        "        return optax.adam(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps)\n",
        "    elif name == \"sgd\":\n",
        "        return optax.sgd(\n",
        "            learning_rate=learning_rate, momentum=momentum, nesterov=nesterov\n",
        "        )\n",
        "    elif name == \"rmsprop\":\n",
        "        return optax.rmsprop(\n",
        "            learning_rate=learning_rate,\n",
        "            decay=b1,\n",
        "            eps=eps,\n",
        "            momentum=momentum,\n",
        "            nesterov=nesterov,\n",
        "        )\n",
        "    elif name == \"adafactor\":\n",
        "        return optax.adafactor(learning_rate=learning_rate)\n",
        "    elif name == \"adagrad\":\n",
        "        return optax.adagrad(learning_rate=learning_rate, eps=eps)\n",
        "    elif name == \"adadelta\":\n",
        "        return optax.adadelta(learning_rate=learning_rate, eps=eps)\n",
        "    elif name == \"lion\":\n",
        "        return optax.lion(learning_rate=learning_rate, b1=b1, b2=b2)\n",
        "    elif name == \"adamax\":\n",
        "        return optax.adamax(learning_rate=learning_rate, b1=b1, b2=b2, eps=eps)\n",
        "    elif name == \"fromage\":\n",
        "        return optax.fromage(learning_rate=learning_rate)\n",
        "    elif name == \"noisy_sgd\":\n",
        "        return optax.noisy_sgd(\n",
        "            learning_rate=learning_rate, eta=kwargs.get(\"noise_scale\", 0.01)\n",
        "        )\n",
        "    else:\n",
        "        raise ValueError(f\"Unknown optimizer: {name}\")\n",
        "\n",
        "\n",
        "def _setup_optimizer(\n",
        "    config: TrainingExperimentConfig,\n",
        ") -> Tuple[optax.GradientTransformation, Callable]:\n",
        "    \"\"\"Sets up the optimizer and learning rate schedule based on the configuration.\"\"\"\n",
        "    if config.training.gradient_clip_norm is not None:\n",
        "        gradient_clipping = optax.clip_by_global_norm(\n",
        "            config.training.gradient_clip_norm\n",
        "        )\n",
        "    elif config.training.gradient_clip is not None:\n",
        "        gradient_clipping = optax.clip(config.training.gradient_clip)\n",
        "    else:\n",
        "        gradient_clipping = optax.identity()\n",
        "\n",
        "    lr_schedule_fn = lambda step: config.training.learning_rate\n",
        "\n",
        "    base_optimizer = get_optimizer(\n",
        "        config.training.optimizer,\n",
        "        lr_schedule_fn,  # Pass the function itself\n",
        "        weight_decay=config.training.weight_decay,\n",
        "        b1=config.training.beta1,\n",
        "        b2=config.training.beta2,\n",
        "        eps=config.training.epsilon,\n",
        "        momentum=config.training.momentum,\n",
        "        nesterov=config.training.nesterov,\n",
        "        noise_scale=config.training.noise_scale,\n",
        "    )\n",
        "\n",
        "    optimizer = optax.chain(optax.zero_nans(), gradient_clipping, base_optimizer)\n",
        "\n",
        "    if config.training.every_k_schedule > 1:\n",
        "        optimizer = optax.MultiSteps(\n",
        "            optimizer, every_k_schedule=config.training.every_k_schedule\n",
        "        )\n",
        "\n",
        "    # Consider adding apply_if_finite if needed:\n",
        "    # optimizer = optax.apply_if_finite(optimizer, max_finite_updates=5)\n",
        "\n",
        "    return optimizer, lr_schedule_fn  # Return schedule_fn as well for logging"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 131,
      "metadata": {
        "id": "tDlBULkcKHS2"
      },
      "outputs": [],
      "source": [
        "def _setup_path_distribution(\n",
        "    initial_density, target_density, config: TrainingExperimentConfig\n",
        ") -> AnnealedDistribution:\n",
        "    \"\"\"Creates the annealed path distribution.\"\"\"\n",
        "    return AnnealedDistribution(\n",
        "        initial_density=initial_density,\n",
        "        target_density=target_density,\n",
        "        method=config.density.annealing_path,\n",
        "    )"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 132,
      "metadata": {
        "id": "14ybZTAUKPBn"
      },
      "outputs": [],
      "source": [
        "from jaxtyping import PyTree\n",
        "import jax.sharding as jshard\n",
        "\n",
        "\n",
        "@eqx.filter_jit\n",
        "def _execute_jitted_step(\n",
        "    key: jax.random.PRNGKey,\n",
        "    v_theta: PyTree,\n",
        "    opt_state: PyTree,\n",
        "    particles: Particle,  # Dynamic arg\n",
        "    optimizer: optax.GradientTransformation,  # Static arg\n",
        "    path_distribution_time_derivative: Callable,  # Static arg\n",
        "    path_distribution_score_fn: Callable,  # Static arg\n",
        "    config_density_shift_fn: Callable,  # Static arg\n",
        "    config_training_estimator: str,  # Static arg\n",
        "    config_training_n_probes: int,  # Static arg\n",
        "    config_training_use_combined_loss: bool,  # Static arg\n",
        "    config_training_shortcut_weight: float,  # Static arg\n",
        "    config_training_random_alpha: bool,  # Static arg - Re-added\n",
        "    config_model_dropout: Optional[float],  # Static arg\n",
        "    config_skip_shortcut: bool,  # Static arg\n",
        "    config_estimate_dt_LogZ: bool,  # Static arg\n",
        "    config_learnable_path: bool,  # Static arg\n",
        "    # Sharding specifications (static, passed in):\n",
        "    param_replication: Optional[jshard.NamedSharding],  # Static arg\n",
        "    data_sharding_rank1: Optional[jshard.NamedSharding],  # Static arg\n",
        "    data_sharding_rank2: Optional[jshard.NamedSharding],  # Static arg\n",
        "    probe_sharding_rank2: Optional[jshard.NamedSharding],  # Static arg\n",
        "    probe_sharding_rank3: Optional[jshard.NamedSharding],  # Static arg\n",
        ") -> Tuple[\n",
        "    PyTree, PyTree, Float[Array, \"\"], Float[Array, \" batch\"]\n",
        "]:  # Added raw_epsilons to return type\n",
        "    \"\"\"\n",
        "    Performs a single training step (loss, gradients, update). JIT compiled.\n",
        "    Handles data parallelism sharding.\n",
        "    Also returns the raw epsilons calculated for the given particles and updated model,\n",
        "    to be potentially used for RAD sampling.\n",
        "    \"\"\"\n",
        "    # Split key for potentially independent random operations: probes, alpha, dropout\n",
        "    key_probes, key_alpha, key_dropout = jax.random.split(key, 3)\n",
        "\n",
        "    # --- Generate Probes BEFORE loss function definition (if needed) ---\n",
        "    probes = None\n",
        "    if config_training_estimator == \"hutchinson\":\n",
        "        # Determine shape based on n_probes\n",
        "        if config_training_n_probes > 1:\n",
        "            probe_shape = (\n",
        "                particles.x.shape[0],\n",
        "                config_training_n_probes,\n",
        "                particles.x.shape[1],\n",
        "            )  # Shape: (batch, n_probes, dim)\n",
        "        else:\n",
        "            probe_shape = (\n",
        "                particles.x.shape[0],\n",
        "                particles.x.shape[1],\n",
        "            )  # Shape: (batch, dim)\n",
        "\n",
        "        probes = jax.random.rademacher(\n",
        "            key_probes,  # Use dedicated key\n",
        "            shape=probe_shape,\n",
        "            dtype=particles.x.dtype,\n",
        "        )\n",
        "        # Note: key_alpha is used below if random_alpha shortcut is enabled.\n",
        "\n",
        "    # --- Apply sharding (using pre-defined static sharding specs) ---\n",
        "    if param_replication is not None:  # Check if any sharding is active\n",
        "        # Shard model and optimizer state (replication)\n",
        "        v_theta, opt_state = eqx.filter_shard((v_theta, opt_state), param_replication)\n",
        "\n",
        "        # Shard probes if they exist, using the appropriate pre-defined spec\n",
        "        if probes is not None:\n",
        "            if jnp.ndim(probes) == 3:  # (batch, n_probes, dim)\n",
        "                probes = eqx.filter_shard(probes, probe_sharding_rank3)\n",
        "            else:  # Rank 2 (batch, dim)\n",
        "                probes = eqx.filter_shard(probes, probe_sharding_rank2)\n",
        "\n",
        "        # Create sharding PyTree for particles using pre-defined specs\n",
        "        particle_sharding = jax.tree.map(\n",
        "            lambda leaf: data_sharding_rank1\n",
        "            if jnp.ndim(leaf) == 1\n",
        "            else data_sharding_rank2,\n",
        "            particles,  # Use the actual particle structure\n",
        "            is_leaf=lambda x: x is None,  # Treat None leaves correctly\n",
        "        )\n",
        "        # Apply sharding to the particles PyTree\n",
        "        particles = eqx.filter_shard(particles, particle_sharding)\n",
        "\n",
        "    # Define the loss function specific to this step's context\n",
        "    # This internal function now computes BOTH loss and raw epsilons\n",
        "    @eqx.filter_jit\n",
        "    def compute_loss_and_epsilons(\n",
        "        model, current_particles, current_probes, current_alpha\n",
        "    ):  # Pass alpha explicitly\n",
        "        # loss_fn now returns (loss_value, raw_epsilons)\n",
        "        loss_value, raw_epsilons = loss_fn(\n",
        "            model,\n",
        "            current_particles,  # Use passed particles\n",
        "            path_distribution_time_derivative,\n",
        "            path_distribution_score_fn,\n",
        "            alpha=current_alpha,  # Pass captured alpha\n",
        "            combined_loss=config_training_use_combined_loss,\n",
        "            shortcut_weight=config_training_shortcut_weight,\n",
        "            skip_shortcut=config_skip_shortcut,\n",
        "        )\n",
        "        return loss_value, raw_epsilons  # Return both\n",
        "\n",
        "    # --- Generate Alpha values if needed ---\n",
        "    alpha_values = None\n",
        "    if (\n",
        "        config_training_random_alpha\n",
        "        and particles.d is not None\n",
        "        and not config_skip_shortcut\n",
        "    ):\n",
        "        alpha_values = jax.random.uniform(key_alpha, shape=(particles.x.shape[0],))\n",
        "\n",
        "    # --- Calculate Loss, Gradients, and get Epsilons ---\n",
        "    # Use has_aux=True to get the raw_epsilons returned by compute_loss_and_epsilons\n",
        "    # Pass particles, probes, and alpha to the value_and_grad function\n",
        "    (loss, raw_epsilons), grads = eqx.filter_value_and_grad(\n",
        "        compute_loss_and_epsilons, has_aux=True\n",
        "    )(\n",
        "        v_theta, particles, probes, alpha_values\n",
        "    )  # Pass particles, probes, and alpha here\n",
        "\n",
        "    # --- Apply Updates ---\n",
        "    updates, opt_state = optimizer.update(\n",
        "        grads, opt_state, eqx.filter(v_theta, eqx.is_array)\n",
        "    )\n",
        "    v_theta_updated = eqx.apply_updates(v_theta, updates)\n",
        "\n",
        "    # --- Ensure outputs have correct sharding before returning ---\n",
        "    if param_replication is not None:  # Check if sharding is active\n",
        "        # Shard model and optimizer state (replication)\n",
        "        v_theta_updated, opt_state = eqx.filter_shard(\n",
        "            (v_theta_updated, opt_state), param_replication\n",
        "        )\n",
        "\n",
        "        # Shard raw_epsilons if they exist, using appropriate data sharding spec\n",
        "        if raw_epsilons is not None:\n",
        "            # Use the pre-defined sharding specs passed as arguments\n",
        "            if jnp.ndim(raw_epsilons) == 1:\n",
        "                raw_epsilons = eqx.filter_shard(raw_epsilons, data_sharding_rank1)\n",
        "            else:  # Assume rank 2\n",
        "                raw_epsilons = eqx.filter_shard(raw_epsilons, data_sharding_rank2)\n",
        "        # Loss is scalar, JAX handles reduction automatically.\n",
        "\n",
        "    # --- Return updated state and auxiliary data ---\n",
        "    # No need to recalculate epsilons here!\n",
        "    return v_theta_updated, opt_state, loss, raw_epsilons"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 133,
      "metadata": {
        "id": "oM85Wrj1K2pK"
      },
      "outputs": [],
      "source": [
        "def _prepare_step_batch(\n",
        "    key: jax.random.PRNGKey,\n",
        "    samples: Float[Array, \"time batch dim\"],\n",
        "    current_ts: Float[Array, \" time\"],\n",
        "    log_Z_t: Float[Array, \" time\"],\n",
        "    config: TrainingExperimentConfig,\n",
        ") -> Particle:  # Removed key from return tuple and sampler parameter\n",
        "    \"\"\"\n",
        "    Selects a batch of chains and prepares the Particle object for a training step.\n",
        "\n",
        "    Prepare a batch of (x,t) pairs for one optimisation step.\n",
        "\n",
        "    Args:\n",
        "        key: Random key for this operation (will be split internally).\n",
        "        samples: All available samples (time, particle, dim).\n",
        "        current_ts: Time steps corresponding to samples.\n",
        "        log_Z_t: Estimated log partition function at each time step.\n",
        "        config: Training configuration.\n",
        "\n",
        "    Returns:\n",
        "        The prepared Particle object for the training step (pre-augmentation).\n",
        "    \"\"\"\n",
        "    # Split key for independent random operations within this function\n",
        "    key_choice, key_uniform = jax.random.split(key)\n",
        "\n",
        "    time_steps, num_total_particles, dim = samples.shape\n",
        "    # Number of distinct trajectories to sample per step\n",
        "    num_chains_per_step = config.training.time_batch_size\n",
        "\n",
        "    _loss_weight = None  # Time IS removed, so loss_weight is always None here\n",
        "\n",
        "    # Standard batch preparation (previously the 'else' block)\n",
        "    if num_chains_per_step > num_total_particles:\n",
        "        print(\n",
        "            f\"Warning: time_batch_size ({num_chains_per_step}) > num available particles ({num_total_particles}). Sampling with replacement.\"\n",
        "        )\n",
        "        replace = True\n",
        "    else:\n",
        "        replace = False\n",
        "\n",
        "    chain_indices = jax.random.choice(\n",
        "        key_choice,\n",
        "        num_total_particles,\n",
        "        shape=(num_chains_per_step,),\n",
        "        replace=replace,  # Use key_choice\n",
        "    )\n",
        "    # Select chains: (time_steps, num_chains_per_step, dim)\n",
        "    selected_chains_time_major = samples[:, chain_indices, :]\n",
        "    # Reshape to batch major: (time_steps * num_chains_per_step, dim)\n",
        "    # This combines time and chain index into the batch dimension for the step\n",
        "    selected_chains = selected_chains_time_major.reshape(\n",
        "        time_steps * num_chains_per_step, dim\n",
        "    )\n",
        "\n",
        "    # Repeat time and log_Z_t for the selected batch\n",
        "    selected_t = jnp.repeat(current_ts, num_chains_per_step)\n",
        "    selected_log_Z_t = (\n",
        "        jnp.repeat(log_Z_t, num_chains_per_step)\n",
        "        if not config.training.estimate_logz\n",
        "        else None\n",
        "    )\n",
        "\n",
        "    # Removed redundant key split: key, subkey = jax.random.split(key)\n",
        "\n",
        "    _d = None\n",
        "    if config.training.use_shortcut:\n",
        "        if config.training.skip_shortcut:\n",
        "            _d = jnp.ones(selected_t.shape) / config.sampling.num_timesteps\n",
        "        else:\n",
        "            _d = jax.random.uniform(key_uniform, selected_t.shape)  # Use key_uniform\n",
        "\n",
        "    training_particles = Particle(\n",
        "        x=selected_chains,  # Will be potentially augmented later\n",
        "        t=selected_t,\n",
        "        log_Z_t=selected_log_Z_t,\n",
        "        d=_d,\n",
        "        loss_weight=_loss_weight,\n",
        "    )\n",
        "    # Return selected_chains before augmentation for the augmentation function\n",
        "    # Removed key from return tuple\n",
        "    return training_particles\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 134,
      "metadata": {
        "id": "ri1D8kcNK_70"
      },
      "outputs": [],
      "source": [
        "def _apply_augmentations(\n",
        "    key: jax.random.PRNGKey,\n",
        "    selected_chains: Float[Array, \"batch_size dim\"],\n",
        "    config: TrainingExperimentConfig,\n",
        ") -> Float[Array, \"batch_size dim\"]:  # Remove key from return tuple\n",
        "    \"\"\"\n",
        "    Applies configured augmentations to the selected chains.\n",
        "\n",
        "    Args:\n",
        "        key: Random key for this operation (will be split internally).\n",
        "        selected_chains: The batch of samples to augment.\n",
        "        config: Training configuration containing augmentation settings.\n",
        "\n",
        "    Returns:\n",
        "        The augmented batch of samples.\n",
        "    \"\"\"\n",
        "    augmented_chains = selected_chains\n",
        "    # Split key for independent random operations (augmentation, perturbation)\n",
        "    key_aug, key_perturb = jax.random.split(key)\n",
        "\n",
        "    if config.training.perturb:\n",
        "        # Removed redundant split: key, subkey = jax.random.split(key)\n",
        "        noise = (\n",
        "            jax.random.normal(\n",
        "                key_perturb,\n",
        "                augmented_chains.shape,\n",
        "                dtype=config.mp_policy.output_dtype,  # Use key_perturb\n",
        "            )\n",
        "            * config.training.perturbation_scale\n",
        "        )\n",
        "        augmented_chains = augmented_chains + noise\n",
        "\n",
        "    return augmented_chains  # Return only augmented chains\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 135,
      "metadata": {
        "id": "OiykQZ77LFv9"
      },
      "outputs": [],
      "source": [
        "def _compute_lambda_factor(\n",
        "    global_step: int,\n",
        "    lambda_total_steps: int,\n",
        "    lambda_max: float,\n",
        "    # lambda_epochs: int # Removed as redundant if lambda_total_steps is correct\n",
        ") -> Float[Array, \"\"]:\n",
        "    \"\"\"Calculates the lambda factor based on the current training progress.\"\"\"\n",
        "    progress_ratio = jnp.minimum(1.0, global_step / lambda_total_steps)\n",
        "    # Exponential growth from almost 0 to lambda_max\n",
        "    return jnp.array(\n",
        "        lambda_max * (1.0 - jnp.exp(-5.0 * progress_ratio)), dtype=jnp.float32\n",
        "    )\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 136,
      "metadata": {
        "id": "FHrxn56NLtKd"
      },
      "outputs": [],
      "source": [
        "def control_variate_epsilon(\n",
        "    v_theta: Callable[[Float[Array, \"dim\"], float, Optional[float]], Float[Array, \"dim\"]],\n",
        "    x: Float[Array, \"dim\"],\n",
        "    t: float,\n",
        "    score_fn: Callable[[Float[Array, \"dim\"], float], Float[Array, \"dim\"]],\n",
        "    d: Optional[Float[Array, \"\"]] = None,\n",
        ") -> Float[Array, \"\"]:\n",
        "    \"\"\"Use control variate to reduce variance of the normalizing constant estimate.\n",
        "\n",
        "    Args:\n",
        "        v_theta: The velocity field function taking (x, t) and returning velocity vector\n",
        "        x: The point at which to compute the error\n",
        "        t: Current time\n",
        "        score_fn: Score function taking (x, t) and returning gradient of log density\n",
        "        d: Shortcut distance\n",
        "        use_hutchinson: Whether to use Hutchinson's trick\n",
        "        key: PRNG key for Hutchinson's trick\n",
        "        n_probes: Number of probes for Hutchinson's trick\n",
        "\n",
        "    Returns:\n",
        "        float: Local error in satisfying the Liouville equation\n",
        "    \"\"\"\n",
        "    # Calculate divergence using appropriate method\n",
        "    if d is not None:\n",
        "        div_v = divergence_velocity_with_shortcut(v_theta, x, t, d=d)\n",
        "        v = v_theta(x, t, d)\n",
        "    else:\n",
        "        div_v = divergence_velocity(v_theta, x, t)\n",
        "        v = v_theta(x, t)\n",
        "\n",
        "    # Get score and calculate dot product with better numerical stability\n",
        "    score = score_fn(x, t)\n",
        "    v_dot_score = jnp.sum(v * score)  # element-wise multiply then sum is more stable\n",
        "\n",
        "    # Calculate final result and handle NaN/inf values\n",
        "    result = div_v + v_dot_score\n",
        "\n",
        "    return jnp.nan_to_num(\n",
        "        result,\n",
        "        nan=0.0,\n",
        "        posinf=1.0,\n",
        "        neginf=-1.0,\n",
        "    )\n",
        "\n",
        "\n",
        "# vmap over particles, keeping other args static\n",
        "batched_control_variate_epsilon = jax.vmap(\n",
        "    control_variate_epsilon,\n",
        "    in_axes=(None, 0, None, None, None) # v_theta, x, t, score_fn, d, use_hutchinson, key (per-particle if needed), n_probes\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 137,
      "metadata": {
        "id": "9GG_fSqALkIe"
      },
      "outputs": [],
      "source": [
        "@eqx.filter_jit\n",
        "def estimate_log_Z_t(\n",
        "    xs: Float[Array, \"time num_particles dim\"],\n",
        "    weights: Float[Array, \"time num_particles\"],\n",
        "    ts: Float[Array, \"time\"],\n",
        "    time_derivative_log_density: Callable[[Float[Array, \"dim\"], float], float],\n",
        "    v_theta: Optional[Callable[[Float[Array, \"dim\"], float, Optional[float]], Float[Array, \"dim\"]]] = None,\n",
        "    score_fn: Optional[Callable[[Float[Array, \"dim\"], float], Float[Array, \"dim\"]]] = None,\n",
        "    use_control_variate: bool = False,\n",
        "    use_shortcut: bool = False,\n",
        "    key: Optional[PRNGKeyArray] = None,\n",
        "    # --- Add Sharding Specs ---\n",
        "    param_replication: Optional[jshard.NamedSharding] = None,\n",
        "    xs_sharding_spec: Optional[jshard.NamedSharding] = None,\n",
        "    weights_sharding_spec: Optional[jshard.NamedSharding] = None,\n",
        ") -> Float[Array, \"1\"]:\n",
        "    \"\"\"Estimate log Z using jax.lax.scan, handling varying step sizes for shortcut CV.\n",
        "\n",
        "    Args:\n",
        "        xs: Samples from the distribution (time, num_particles, dim)\n",
        "        weights: Importance weights for the samples (time, num_particles)\n",
        "        ts: Time points (time,). Assumed sorted.\n",
        "        time_derivative_log_density: Function computing time derivative of log density\n",
        "        v_theta: Velocity field function\n",
        "        score_fn: Score function\n",
        "        use_control_variate: Whether to use control variate\n",
        "        use_shortcut: Whether to use shortcut distance (calculates 'd' based on ts diffs)\n",
        "        use_hutchinson: Whether to use Hutchinson estimator for divergence\n",
        "        key: PRNG key for Hutchinson estimator\n",
        "        n_probes: Number of probes for Hutchinson estimator\n",
        "\n",
        "    Returns:\n",
        "        Estimate of log partition function (scalar, shape (1,))\n",
        "    \"\"\"\n",
        "    # --- Apply Sharding ---\n",
        "    if xs_sharding_spec is not None and weights_sharding_spec is not None:\n",
        "        xs = eqx.filter_shard(xs, xs_sharding_spec)\n",
        "        weights = eqx.filter_shard(weights, weights_sharding_spec)\n",
        "        # ts = eqx.filter_shard(ts, jshard.NamedSharding(data_sharding_spec.mesh, jshard.PartitionSpec(None))) # Replicate if needed\n",
        "\n",
        "    if param_replication is not None and v_theta is not None:\n",
        "        v_theta = eqx.filter_shard(v_theta, param_replication)\n",
        "\n",
        "    # --- Prepare for Scan ---\n",
        "    if use_control_variate:\n",
        "        if v_theta is None or score_fn is None:\n",
        "            raise ValueError(\"v_theta and score_fn must be provided when use_control_variate is True.\")\n",
        "\n",
        "    # --- Precompute step durations (dts) if needed ---\n",
        "    dts_full = None # Will hold step durations if use_shortcut is True\n",
        "    if use_control_variate and use_shortcut:\n",
        "        if ts.shape[0] < 2:\n",
        "             # If only one time point, dt is arguably 0.\n",
        "             dts_full = jnp.zeros((1,)) if ts.shape[0] == 1 else jnp.zeros((0,))\n",
        "             # Warning: Shortcut CV might not be meaningful with < 2 points.\n",
        "        else:\n",
        "            # Calculate differences: dt[i] = ts[i+1] - ts[i]\n",
        "            dts = jnp.diff(ts, ) # Shape: (time-1,)\n",
        "\n",
        "            # Define dt for the first step (t=ts[0]). Common choices:\n",
        "            # 1. Use dts[0] (duration of first interval ts[1]-ts[0])\n",
        "            # 2. Use 0.0 (if ts[0] is pre-step)\n",
        "            # Let's use dts[0] assuming the calc at ts[i] relates to interval ending at ts[i].\n",
        "            final_dts = dts[-1]\n",
        "            # first_dt = 0.0 # Alternative\n",
        "\n",
        "            # Create full dt array matching length of ts\n",
        "            dts_full = jnp.concatenate([dts, jnp.array([final_dts])]) # Shape: (time,)\n",
        "\n",
        "    # --- Define the Scan Body ---\n",
        "    def scan_body(carry, scan_slice):\n",
        "        current_key = carry\n",
        "\n",
        "        # Unpack slice - depends on whether dts_full is included\n",
        "        if use_control_variate and use_shortcut:\n",
        "            x_t, w_t, t, dt_t = scan_slice # dt_t is the precomputed step duration for this step\n",
        "            d_val_step = dt_t\n",
        "        else:\n",
        "            x_t, w_t, t = scan_slice\n",
        "            d_val_step = None # d_val is not used\n",
        "\n",
        "        # 1. Calculate base term (vmap ONLY over particles)\n",
        "        dt_log_density_t = jax.vmap(lambda x: time_derivative_log_density(x, t))(x_t)\n",
        "\n",
        "        combined_term_t = dt_log_density_t\n",
        "        next_key = jax.random.fold_in(current_key, t) # Update key with time for reproducibility\n",
        "\n",
        "        # 2. Calculate control variate term if needed\n",
        "        if use_control_variate:\n",
        "            eps_t = batched_control_variate_epsilon(\n",
        "                v_theta, x_t, t, score_fn,\n",
        "                d_val_step, # Use the specific dt for this step\n",
        "            )\n",
        "            combined_term_t = combined_term_t + eps_t\n",
        "\n",
        "        # 3. Compute weighted sum for this step\n",
        "        step_contribution = jnp.sum(combined_term_t * w_t)\n",
        "\n",
        "        # 5. Return updated carry and optional per-step output\n",
        "        new_carry = next_key\n",
        "        per_step_output = step_contribution\n",
        "        return new_carry, per_step_output\n",
        "\n",
        "    # --- Run the Scan ---\n",
        "    initial_carry = key\n",
        "\n",
        "    # Prepare inputs for scan, including dts_full if computed\n",
        "    if use_control_variate and use_shortcut:\n",
        "        scan_inputs = (xs, weights, ts, dts_full)\n",
        "    else:\n",
        "        scan_inputs = (xs, weights, ts)\n",
        "\n",
        "    # Execute the scan\n",
        "    _, final_log_Z_estimate = jax.lax.scan(scan_body, initial_carry, scan_inputs)\n",
        "\n",
        "    # Reshape and clean up\n",
        "    final_result = final_log_Z_estimate.flatten()\n",
        "    return jnp.nan_to_num(final_result, posinf=1.0, neginf=-1.0)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 138,
      "metadata": {
        "id": "jsL_Dvq5LHq4"
      },
      "outputs": [],
      "source": [
        "def _maybe_estimate_log_z(\n",
        "    key: jax.random.PRNGKey,\n",
        "    epoch: int,\n",
        "    v_theta: PyTree,\n",
        "    current_ts: Float[Array, \" time\"],\n",
        "    path_distribution: AnnealedDistribution,\n",
        "    config: TrainingExperimentConfig,\n",
        "    log_Z_t_ref: List[Optional[Float[Array, \" time\"]]],\n",
        "    current_lambda: Float[Array, \"\"],\n",
        "    # --- Add Sharding Specs ---\n",
        "    param_replication: Optional[jshard.NamedSharding] = None,\n",
        "    xs_sharding_spec: Optional[\n",
        "        jshard.NamedSharding\n",
        "    ] = None,  # Spec for particle dimension sharding\n",
        "    weights_sharding_spec: Optional[\n",
        "        jshard.NamedSharding\n",
        "    ] = None,  # Spec for weights sharding\n",
        ") -> Tuple[\n",
        "    Float[Array, \" time\"], Optional[Dict[str, Any]]\n",
        "]:  # Remove key from return tuple\n",
        "    \"\"\"\n",
        "    Estimates log_Z_t if required for the current epoch, otherwise reuses the previous value.\n",
        "\n",
        "    Args:\n",
        "        key: Random key for this operation (will be split internally).\n",
        "        epoch: Current epoch number.\n",
        "        v_theta: Current velocity field model.\n",
        "        current_ts: Current time steps.\n",
        "        path_distribution: Annealed distribution.\n",
        "        config: Training configuration.\n",
        "        log_Z_t_ref: Mutable list containing the last estimated log_Z_t.\n",
        "        current_lambda: Current lambda factor for MCMC.\n",
        "\n",
        "    Returns:\n",
        "        A tuple containing:\n",
        "            - The estimated or reused log_Z_t.\n",
        "            - The MCMC samples generated during estimation (or None if reused).\n",
        "    \"\"\"\n",
        "    mcmc_samples = None  # Initialize\n",
        "    should_estimate = (\n",
        "        (epoch % config.training.log_z_estimation_frequency == 0)\n",
        "        or (epoch == 0)\n",
        "        or (log_Z_t_ref[0] is None)\n",
        "    )\n",
        "\n",
        "    ts_for_estimation = current_ts  # Start with current ts\n",
        "\n",
        "    if should_estimate:\n",
        "        print(f\"Epoch {epoch}: Estimating log Z(t)...\")\n",
        "        # key, subkey_sample = jax.random.split(key) # Will use the passed key directly\n",
        "\n",
        "        mcmc_samples = generate_samples_with_optional_mcmc(\n",
        "            key,\n",
        "            v_theta,\n",
        "            ts_for_estimation,\n",
        "            path_distribution,\n",
        "            config,  # Pass the key directly\n",
        "            mcmc_method=config.mcmc.method,\n",
        "            force_finite=True,\n",
        "            lambda_factor=current_lambda,\n",
        "        )\n",
        "\n",
        "        # Estimate Log Z using the determined time steps (ts_for_estimation)\n",
        "        print(\n",
        "            f\"Using standard log Z estimation (Control Variate: {config.mcmc.use_control_variate}).\"\n",
        "        )\n",
        "        log_Z_t = estimate_log_Z_t(\n",
        "            mcmc_samples[\"positions\"],\n",
        "            mcmc_samples[\"weights\"],\n",
        "            ts_for_estimation,  # Use potentially updated ts\n",
        "            path_distribution.time_derivative,\n",
        "            v_theta=v_theta,\n",
        "            score_fn=path_distribution.score_fn,\n",
        "            use_control_variate=config.mcmc.use_control_variate,\n",
        "            use_shortcut=config.training.use_shortcut,\n",
        "            key=key,  # Use the key available in this scope\n",
        "            # --- Pass Sharding Specs ---\n",
        "            param_replication=param_replication,\n",
        "            xs_sharding_spec=xs_sharding_spec,\n",
        "            weights_sharding_spec=weights_sharding_spec,\n",
        "        )\n",
        "\n",
        "        log_Z_t_ref[0] = log_Z_t  # Update the shared reference\n",
        "\n",
        "        # Logging\n",
        "        log_Z_t_to_log = jnp.nan_to_num(log_Z_t, nan=0.0, posinf=1.0, neginf=-1.0)\n",
        "\n",
        "        print(f\"Epoch {epoch}, Log Z(t) estimated (shape {log_Z_t.shape})\")\n",
        "        if \"ess\" in mcmc_samples:\n",
        "            print(f\"Epoch {epoch}, MCMC Samples ESS: {mcmc_samples['ess']}\")\n",
        "\n",
        "    else:\n",
        "        # Reuse the previous estimation\n",
        "        log_Z_t = log_Z_t_ref[0]\n",
        "        if log_Z_t is None:\n",
        "            # This should not happen if estimation occurs at epoch 0\n",
        "            raise RuntimeError(\n",
        "                \"log_Z_t_ref[0] is None, but should have been estimated in epoch 0.\"\n",
        "            )\n",
        "        print(f\"Epoch {epoch}: Reusing previous log Z(t) estimation.\")\n",
        "\n",
        "    # Return the potentially updated current_ts (especially if ASMC was used)\n",
        "    # Removed key from return tuple\n",
        "    return log_Z_t, mcmc_samples"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 139,
      "metadata": {
        "id": "2SaIoKlBMQfs"
      },
      "outputs": [],
      "source": [
        "def _prepare_epoch_samples(\n",
        "    key: jax.random.PRNGKey,\n",
        "    v_theta: PyTree,\n",
        "    current_ts: Float[Array, \" time\"],\n",
        "    path_distribution: AnnealedDistribution,\n",
        "    config: TrainingExperimentConfig,\n",
        "    mcmc_samples_from_logz: Optional[\n",
        "        Dict[str, Any]\n",
        "    ],  # Samples from log Z estimation step\n",
        "    current_lambda: Float[Array, \"\"],\n",
        ") -> Float[Array, \"time batch dim\"]:  # Removed key from return tuple\n",
        "    \"\"\"\n",
        "    Prepares the pool of samples for the epoch's training steps based on config.\n",
        "\n",
        "    Args:\n",
        "        key: Random key for this operation (will be split internally).\n",
        "        v_theta: Current velocity field model.\n",
        "        current_ts: Current time steps.\n",
        "        path_distribution: Annealed distribution.\n",
        "        config: Training configuration.\n",
        "        mcmc_samples_from_logz: Samples generated during log Z estimation (if any).\n",
        "        current_lambda: Current lambda factor for MCMC.\n",
        "\n",
        "    Returns:\n",
        "        The prepared pool of samples for the epoch's training steps.\n",
        "    \"\"\"\n",
        "    base_mcmc_samples: Dict[str, Any]\n",
        "    # Split key for potentially independent sampling operations\n",
        "    key_base_mcmc, key_v_theta, key_random = jax.random.split(key, 3)\n",
        "\n",
        "    _need_mcmc_samples = (\n",
        "        config.training.training_data == \"combined\"\n",
        "        or config.training.training_data == \"vsmc\"\n",
        "    )\n",
        "\n",
        "    if _need_mcmc_samples:\n",
        "        if mcmc_samples_from_logz is None:\n",
        "            # This happens if log Z wasn't estimated this epoch. Need to generate base samples.\n",
        "            # Use the standard MCMC method defined in config for generating training samples.\n",
        "            print(\n",
        "                \"Epoch: Generating base MCMC samples as none were provided (log Z reused).\"\n",
        "            )\n",
        "            # key, subkey = jax.random.split(key) # Remove redundant split\n",
        "            base_mcmc_samples = generate_samples_with_optional_mcmc(\n",
        "                key_base_mcmc,\n",
        "                v_theta,\n",
        "                current_ts,\n",
        "                path_distribution,\n",
        "                config,  # Use key_base_mcmc\n",
        "                mcmc_method=config.mcmc.method,\n",
        "                force_finite=True,\n",
        "                lambda_factor=current_lambda,\n",
        "            )\n",
        "        else:\n",
        "            # Reuse samples generated during log Z estimation\n",
        "            print(\"Epoch: Reusing MCMC samples generated during log Z estimation.\")\n",
        "            base_mcmc_samples = mcmc_samples_from_logz\n",
        "\n",
        "    if config.training.training_data == \"vsmc\":\n",
        "        return base_mcmc_samples[\"positions\"]  # Remove key from return\n",
        "    elif config.training.training_data == \"combined\":\n",
        "        print(\"Epoch: Generating additional samples for combined training data.\")\n",
        "        # key, subkey = jax.random.split(key) # Remove redundant split\n",
        "        # Generate samples using only the velocity field (no MCMC correction)\n",
        "        v_theta_samples_dict = generate_samples_with_optional_mcmc(\n",
        "            key_v_theta,\n",
        "            v_theta,\n",
        "            current_ts,\n",
        "            path_distribution,\n",
        "            config,  # Use key_v_theta\n",
        "            mcmc_method=\"none\",\n",
        "            force_finite=True,\n",
        "            lambda_factor=current_lambda,\n",
        "            # Ensure same number of particles as base_mcmc_samples\n",
        "            num_samples=base_mcmc_samples[\"positions\"].shape[1],\n",
        "        )\n",
        "        v_theta_samples = v_theta_samples_dict[\"positions\"]\n",
        "\n",
        "        # Concatenate along the particle batch dimension\n",
        "        samples = jnp.concatenate(\n",
        "            [base_mcmc_samples[\"positions\"], v_theta_samples], axis=1\n",
        "        )\n",
        "        print(\n",
        "            f\"Epoch: Concatenated samples for combined training data. New shape: {samples.shape}\"\n",
        "        )\n",
        "\n",
        "        return samples  # Remove key from return\n",
        "    elif config.training.training_data == \"random\":\n",
        "        print(\"Epoch: Generating random samples for training.\")\n",
        "        # key, subkey = jax.random.split(key) # Remove redundant split\n",
        "        # Generate random samples\n",
        "        random_samples = path_distribution.sample_initial(\n",
        "            key_random, (config.sampling.num_timesteps, config.sampling.num_particles)\n",
        "        )  # Use key_random\n",
        "        random_samples = random_samples.reshape(\n",
        "            config.sampling.num_timesteps, config.sampling.num_particles, -1\n",
        "        )\n",
        "        return random_samples  # Remove key from return"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 140,
      "metadata": {
        "id": "aBKjG4DDMdN-"
      },
      "outputs": [],
      "source": [
        "def _run_steps_for_epoch(\n",
        "    key: jax.random.PRNGKey,\n",
        "    v_theta: PyTree,\n",
        "    opt_state: PyTree,\n",
        "    optimizer: optax.GradientTransformation,\n",
        "    lr_schedule_fn: Callable[[int], float],  # Pass the schedule function\n",
        "    epoch: int,\n",
        "    initial_optimizer_step_count: int,  # Added: Track optimizer updates across epochs\n",
        "    samples: Float[Array, \"time batch dim\"],\n",
        "    current_ts: Float[Array, \" time\"],\n",
        "    log_Z_t: Float[Array, \" time\"],\n",
        "    path_distribution: AnnealedDistribution,\n",
        "    config: TrainingExperimentConfig,\n",
        "    # sampler parameter removed\n",
        "    # Sharding specs (passed down):\n",
        "    param_replication: Optional[jshard.NamedSharding] = None,\n",
        "    data_sharding_rank1: Optional[jshard.NamedSharding] = None,\n",
        "    data_sharding_rank2: Optional[jshard.NamedSharding] = None,\n",
        "    probe_sharding_rank2: Optional[jshard.NamedSharding] = None,\n",
        "    probe_sharding_rank3: Optional[jshard.NamedSharding] = None,\n",
        ") -> Tuple[\n",
        "    PyTree, PyTree, Float[Array, \"\"], int\n",
        "]:  # Return signature updated (key removed, added optimizer step count)\n",
        "    \"\"\"\n",
        "    Runs all training steps within a single epoch, potentially including RAD sampling.\n",
        "    Tracks and returns the total number of optimizer steps performed.\n",
        "    Logs steps using the optimizer step count for 'global_step'.\n",
        "    \"\"\"\n",
        "    epoch_loss = 0.0\n",
        "    steps_per_epoch = config.training.steps_per_epoch\n",
        "    current_optimizer_step_count = (\n",
        "        initial_optimizer_step_count  # Initialize counter for this epoch\n",
        "    )\n",
        "\n",
        "    # Pre-compile static arguments for the JITted step function\n",
        "    static_args = (\n",
        "        optimizer,\n",
        "        path_distribution.time_derivative,\n",
        "        path_distribution.score_fn,\n",
        "        config.density.shift_fn,\n",
        "        config.training.estimator,\n",
        "        config.training.n_probes,\n",
        "        config.training.use_combined_loss,\n",
        "        config.training.shortcut_weight,\n",
        "        config.training.random_alpha,  # Re-added\n",
        "        config.model.dropout,\n",
        "        config.training.skip_shortcut,\n",
        "        config.training.estimate_logz,\n",
        "        config.training.learnable_path,\n",
        "        # Pass sharding specs as static args\n",
        "        param_replication,\n",
        "        data_sharding_rank1,\n",
        "        data_sharding_rank2,\n",
        "        probe_sharding_rank2,\n",
        "        probe_sharding_rank3,\n",
        "    )\n",
        "\n",
        "    for s in range(steps_per_epoch):\n",
        "        # global_step_count = epoch * steps_per_epoch + s # Replaced by current_optimizer_step_count for logging 'global_step'\n",
        "        outer_step_index = (\n",
        "            epoch * steps_per_epoch + s\n",
        "        )  # Keep track of outer step for LR logging consistency\n",
        "        # Split key for the operations within this step: batch prep, aug, initial step exec, RAD loop\n",
        "        key_step, key_batch, key_aug, key_init_step, key_rad_loop = jax.random.split(\n",
        "            key, 5\n",
        "        )\n",
        "        key = key_step  # Consume the key for the outer loop\n",
        "\n",
        "        # 1. Prepare Initial Batch (sampler argument removed)\n",
        "        training_particles_pre_aug = _prepare_step_batch(\n",
        "            key_batch, samples, current_ts, log_Z_t, config\n",
        "        )\n",
        "\n",
        "        # 2. Apply Augmentations to Initial Batch\n",
        "        augmented_chains = _apply_augmentations(\n",
        "            key_aug, training_particles_pre_aug.x, config\n",
        "        )\n",
        "\n",
        "        # Create the initial Particle object for the step\n",
        "        initial_particles = Particle(\n",
        "            x=augmented_chains,\n",
        "            t=training_particles_pre_aug.t,\n",
        "            log_Z_t=training_particles_pre_aug.log_Z_t,\n",
        "            d=training_particles_pre_aug.d,\n",
        "            loss_weight=training_particles_pre_aug.loss_weight,  # Keep loss weight if used\n",
        "        )\n",
        "\n",
        "        # 3. Execute JITted Training Step\n",
        "        # (Conditional LBFGS logic removed as it's handled by main.py routing)\n",
        "        v_theta, opt_state, initial_loss, current_epsilons = _execute_jitted_step(\n",
        "            key_init_step, v_theta, opt_state, initial_particles, *static_args\n",
        "        )\n",
        "        current_optimizer_step_count += 1  # Increment after initial step update\n",
        "\n",
        "        # 4. Accumulate Initial Loss (Only this loss contributes to epoch average)\n",
        "        epoch_loss += initial_loss\n",
        "\n",
        "        # --- 5. RARD Sampling Loop (Conditional) ---\n",
        "        if config.training.use_rad_sampling and config.training.rad_steps > 0:\n",
        "            particles_for_resampling = (\n",
        "                initial_particles  # Start with particles from the initial step\n",
        "            )\n",
        "\n",
        "            # Split key for the entire RARD loop's iterations\n",
        "            keys_rad_iters = jax.random.split(key_rad_loop, config.training.rad_steps)\n",
        "\n",
        "            for rad_iter in range(config.training.rad_steps):\n",
        "                # Split key for this specific RARD iteration's operations\n",
        "                key_rad_iter_local = keys_rad_iters[rad_iter]\n",
        "                key_rad_sample, key_rad_aug, key_rad_step_exec = jax.random.split(\n",
        "                    key_rad_iter_local, 3\n",
        "                )\n",
        "\n",
        "                # 5.1. RAD Resample using epsilons from the *previous* step\n",
        "                # Pass path_distribution only if weights are needed\n",
        "                path_dist_for_rad = (\n",
        "                    path_distribution if config.training.rad_use_weights else None\n",
        "                )\n",
        "                resampled_parts_pre_aug, rad_weights = rad_sampler_joint(\n",
        "                    key=key_rad_sample,\n",
        "                    particles=particles_for_resampling,\n",
        "                    epsilons=jnp.abs(\n",
        "                        current_epsilons\n",
        "                    ),  # Use abs(epsilons) for probability calculation\n",
        "                    beta=config.training.rad_beta,\n",
        "                    path_dist=path_dist_for_rad,\n",
        "                    use_weights=config.training.rad_use_weights,\n",
        "                )\n",
        "\n",
        "                # 5.2. Augment the Resampled Batch\n",
        "                augmented_rad_chains = _apply_augmentations(\n",
        "                    key_rad_aug, resampled_parts_pre_aug.x, config\n",
        "                )\n",
        "\n",
        "                # 5.3. Prepare Particle Object for RAD Step Execution\n",
        "                # Create new Particle with augmented chains\n",
        "                rad_step_particles = Particle(\n",
        "                    x=augmented_rad_chains,\n",
        "                    t=resampled_parts_pre_aug.t,\n",
        "                    log_Z_t=resampled_parts_pre_aug.log_Z_t,\n",
        "                    d=resampled_parts_pre_aug.d,  # Keep d if present\n",
        "                    loss_weight=rad_weights\n",
        "                    if config.training.rad_use_weights\n",
        "                    else None,  # Initialize loss_weight, will be updated below if needed\n",
        "                )\n",
        "\n",
        "                # 5.4. Execute Training Step (Backpropagation happens here\n",
        "                # Call the standard JITted step function for RAD step\n",
        "                v_theta, opt_state, rad_step_loss, current_epsilons = (\n",
        "                    _execute_jitted_step(\n",
        "                        key_rad_step_exec,\n",
        "                        v_theta,\n",
        "                        opt_state,\n",
        "                        rad_step_particles,\n",
        "                        *static_args,\n",
        "                    )\n",
        "                )\n",
        "                current_optimizer_step_count += 1  # Increment after RAD step update\n",
        "\n",
        "                # Optional: Log rad_step_loss here if needed for debugging, e.g.,\n",
        "                # Print the RAD step loss\n",
        "                rad_lr = lr_schedule_fn(current_optimizer_step_count - 1)\n",
        "                print(\n",
        "                    f\"Epoch {epoch}, Step {s}, RAD Iter {rad_iter}, Opt Step {current_optimizer_step_count}, Loss: {rad_step_loss:.4f}, LR: {rad_lr:.6f}\"\n",
        "                )\n",
        "\n",
        "                # Update the particles state for the next resampling iteration\n",
        "                particles_for_resampling = rad_step_particles\n",
        "                # current_epsilons is now updated for the next RAD loop or end of loop\n",
        "\n",
        "        # --- End RAD Sampling Loop ---\n",
        "\n",
        "        # 6. Log Step Loss (periodically, based on the initial step's loss)\n",
        "        if s % 20 == 0:  # Log every 20 outer steps\n",
        "            # Log the initial_loss for this step 's'. Use LR corresponding to the start of the outer step for consistency.\n",
        "            # Use the optimizer step count for the 'global_step' field.\n",
        "            lr_for_outer_step = lr_schedule_fn(\n",
        "                outer_step_index\n",
        "            )  # LR at start of outer step s\n",
        "            step_metrics = {\n",
        "                \"loss\": initial_loss,\n",
        "                \"learning_rate\": lr_for_outer_step,  # Log LR associated with the outer step start\n",
        "                \"epoch\": epoch,\n",
        "                \"step\": s,  # Outer step index\n",
        "                \"global_step\": current_optimizer_step_count,  # Log the latest optimizer step count\n",
        "            }\n",
        "\n",
        "            # Print Opt Step count which includes RAD steps if any\n",
        "            print(\n",
        "                f\"Epoch {epoch}, Step {s}, Opt Step {current_optimizer_step_count}, Loss: {initial_loss:.4f}, LR: {lr_for_outer_step:.6f}\"\n",
        "            )\n",
        "\n",
        "    # Calculate average epoch loss based ONLY on the initial steps' losses\n",
        "    avg_epoch_loss = epoch_loss / steps_per_epoch\n",
        "    # Return updated state and the final optimizer step count for this epoch\n",
        "    return v_theta, opt_state, avg_epoch_loss, current_optimizer_step_count"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 141,
      "metadata": {
        "id": "mg8lEjfzM77v"
      },
      "outputs": [],
      "source": [
        "def _calculate_and_log_epoch_metrics(\n",
        "    epoch: int,\n",
        "    avg_train_loss: Float[Array, \"\"],\n",
        "    final_optimizer_step_count_for_epoch: int,  # Added: Use final optimizer step count\n",
        "    v_theta: PyTree,\n",
        "    path_distribution: AnnealedDistribution,\n",
        "    config: TrainingExperimentConfig,\n",
        "    lr_schedule_fn: Callable[[int], float],  # Pass the schedule function\n",
        ") -> None:\n",
        "    \"\"\"Calculates and logs epoch summary metrics using optimizer step count.\"\"\"\n",
        "    # Calculate LR at the *start* of the epoch based on outer steps for consistent reporting\n",
        "    global_step_start_epoch = epoch * config.training.steps_per_epoch\n",
        "\n",
        "    # Validation loss calculation removed.\n",
        "\n",
        "    epoch_lr = lr_schedule_fn(\n",
        "        global_step_start_epoch\n",
        "    )  # LR at start of epoch (based on outer steps)\n",
        "    epoch_metrics = {\n",
        "        \"epoch\": epoch,\n",
        "        \"average_train_loss\": avg_train_loss,\n",
        "        # \"validation_loss_epoch_end\": val_loss, # Removed\n",
        "        \"epoch_learning_rate\": epoch_lr,  # Log LR at start of epoch\n",
        "        \"global_step\": final_optimizer_step_count_for_epoch,  # Log final optimizer step count for epoch\n",
        "    }\n",
        "\n",
        "    print(f\"--- Epoch {epoch} Summary ---\")\n",
        "    print(f\"  Avg Train Loss: {avg_train_loss:.4f}\")\n",
        "    # print(f\"  Validation Loss: {val_loss:.4f}\") # Removed\n",
        "    print(\n",
        "        f\"  Learning Rate at Epoch Start (Outer Step {global_step_start_epoch}): {epoch_lr:.6f}\"\n",
        "    )\n",
        "    print(f\"  Optimizer Steps Completed: {final_optimizer_step_count_for_epoch}\")\n",
        "    print(\"----------------------\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 142,
      "metadata": {
        "id": "b1qIsBZgNQT5"
      },
      "outputs": [],
      "source": [
        "def evaluate_model(\n",
        "    key: jax.random.PRNGKey,\n",
        "    v_theta: Callable,\n",
        "    config: TrainingExperimentConfig,\n",
        "    path_distribution: AnnealedDistribution,\n",
        "    target_density,\n",
        "    current_end_time: float,  # Interpreted as t_final\n",
        "    current_ts: Float[Array, \"t\"] = None,\n",
        ") -> Dict[str, Any]:\n",
        "    \"\"\"Run a single evaluation pass and return metrics.\"\"\"\n",
        "    total_eval_metrics = {}\n",
        "    t_final = current_end_time  # Use a clearer variable name\n",
        "\n",
        "    if config.training.use_shortcut:\n",
        "        eval_steps_config_list = config.training.shortcut_size\n",
        "\n",
        "        for es in eval_steps_config_list:  # es is the number of steps from config\n",
        "            key, eval_key = jax.random.split(key)\n",
        "\n",
        "            # Determine the actual ts to use for this evaluation\n",
        "            ts_to_use_for_this_eval = current_ts\n",
        "            if ts_to_use_for_this_eval is None:\n",
        "                # Generate ts based on es and t_final\n",
        "                ts_to_use_for_this_eval = (\n",
        "                    jnp.linspace(0, t_final, es) if es > 1 else jnp.array([0.0, t_final])\n",
        "                )\n",
        "\n",
        "            eval_metrics = target_density.evaluate(\n",
        "                eval_key,\n",
        "                use_shortcut=True,  # This is the use_shortcut=True branch\n",
        "                ts=ts_to_use_for_this_eval,\n",
        "                v_theta=v_theta,\n",
        "                base_density=path_distribution.initial_density,\n",
        "                # estimate_dt_logZ is typically not passed or False for shortcut\n",
        "            )\n",
        "            # The key uses 'es' from the loop, consistent with original if current_ts overrides.\n",
        "            total_eval_metrics[f\"validation_{es}_step\"] = eval_metrics\n",
        "    else:\n",
        "        # Not config.training.use_shortcut (model not trained with shortcut)\n",
        "        if current_ts is not None:\n",
        "            # If current_ts is provided, evaluate once with it. This takes precedence.\n",
        "            key, eval_key = jax.random.split(key)\n",
        "            eval_metrics = target_density.evaluate(\n",
        "                eval_key,\n",
        "                use_shortcut=False,  # Explicitly False\n",
        "                ts=current_ts,\n",
        "                v_theta=v_theta,\n",
        "                base_density=path_distribution.initial_density,\n",
        "                estimate_dt_logZ=config.training.estimate_logz,\n",
        "            )\n",
        "            # Use a distinct key for this custom_ts evaluation\n",
        "            total_eval_metrics[\n",
        "                f\"validation_{len(current_ts)}_step_custom_ts\"\n",
        "            ] = eval_metrics\n",
        "        else:\n",
        "            # current_ts is None. Evaluate with various numbers of steps.\n",
        "            # Use config.training.shortcut_size as the list of step counts for evaluation.\n",
        "            eval_steps_list_for_loop = config.training.shortcut_size\n",
        "\n",
        "            if not eval_steps_list_for_loop:\n",
        "                # Fallback: If shortcut_size is empty (or not intended for this use),\n",
        "                # perform a single evaluation using config.sampling.num_timesteps.\n",
        "                # This corrects the original linspace generation for this case.\n",
        "                num_default_eval_steps = config.sampling.num_timesteps\n",
        "                ts_for_default_eval = jnp.linspace(\n",
        "                    0, t_final, num_default_eval_steps\n",
        "                )\n",
        "\n",
        "                key, eval_key = jax.random.split(key)\n",
        "                eval_metrics = target_density.evaluate(\n",
        "                    eval_key,\n",
        "                    use_shortcut=False,\n",
        "                    ts=ts_for_default_eval,\n",
        "                    v_theta=v_theta,\n",
        "                    base_density=path_distribution.initial_density,\n",
        "                    estimate_dt_logZ=config.training.estimate_logz,\n",
        "                )\n",
        "                total_eval_metrics[\n",
        "                    f\"validation_{num_default_eval_steps}_step\"\n",
        "                ] = eval_metrics\n",
        "            else:\n",
        "                # Loop through the specified eval_steps_list_for_loop\n",
        "                # (from config.training.shortcut_size)\n",
        "                for es_loop_val in eval_steps_list_for_loop:\n",
        "                    key, eval_key = jax.random.split(key)\n",
        "                    ts_for_this_eval_step = (\n",
        "                        jnp.linspace(0, t_final, es_loop_val)\n",
        "                        if es_loop_val > 1\n",
        "                        else jnp.array([0.0, t_final])\n",
        "                    )\n",
        "\n",
        "                    eval_metrics = target_density.evaluate(\n",
        "                        eval_key,\n",
        "                        use_shortcut=False,  # Explicitly False\n",
        "                        ts=ts_for_this_eval_step,\n",
        "                        v_theta=v_theta,\n",
        "                        base_density=path_distribution.initial_density,\n",
        "                        estimate_dt_logZ=config.training.estimate_logz,\n",
        "                    )\n",
        "                    total_eval_metrics[f\"validation_{es_loop_val}_step\"] = eval_metrics\n",
        "\n",
        "    return total_eval_metrics\n",
        "\n",
        "def aggregate_eval_metrics(\n",
        "    all_eval_results: List[Dict[str, Any]],\n",
        ") -> Dict[str, Dict[str, Any]]:\n",
        "    \"\"\"Aggregate metrics across multiple evaluation runs with proper figure cleanup.\"\"\"\n",
        "    aggregated_metrics = {}\n",
        "\n",
        "    for step_key in all_eval_results[0].keys():\n",
        "        step_metrics = [result[step_key] for result in all_eval_results]\n",
        "        agg_metrics = {}\n",
        "        figures = []\n",
        "\n",
        "        # Collect metrics and figures from all runs\n",
        "        for run_idx, metrics in enumerate(step_metrics):\n",
        "            # Collect figures for later cleanup\n",
        "            if \"figure\" in metrics:\n",
        "                figures.append(metrics[\"figure\"])\n",
        "\n",
        "            # Process numerical metrics\n",
        "            for metric_name in metrics.keys():\n",
        "                if metric_name == \"figure\":\n",
        "                    continue\n",
        "\n",
        "                # Initialize storage if first run\n",
        "                if run_idx == 0:\n",
        "                    agg_metrics[f\"{metric_name}_mean\"] = []\n",
        "                    agg_metrics[f\"{metric_name}_var\"] = []\n",
        "\n",
        "                # Collect values\n",
        "                agg_metrics[f\"{metric_name}_mean\"].append(metrics[metric_name])\n",
        "\n",
        "        # Close all but last figure\n",
        "        for fig in figures[:-1]:\n",
        "            plt.close(fig)\n",
        "\n",
        "        # Add last figure to metrics if exists\n",
        "        if figures:\n",
        "            agg_metrics[\"figure\"] = figures[-1]\n",
        "\n",
        "        # Calculate final mean/var for numerical metrics\n",
        "        for metric_name in list(agg_metrics.keys()):\n",
        "            if \"_mean\" in metric_name:\n",
        "                base_name = metric_name.replace(\"_mean\", \"\")\n",
        "                values_list = agg_metrics.pop(metric_name)\n",
        "                values = jnp.array(values_list)\n",
        "\n",
        "                # IQR Outlier Removal\n",
        "                if values.size > 3: # Need enough data points for IQR\n",
        "                    q1 = jnp.percentile(values, 25)\n",
        "                    q3 = jnp.percentile(values, 75)\n",
        "                    iqr = q3 - q1\n",
        "                    lower_bound = q1 - 1.5 * iqr\n",
        "                    upper_bound = q3 + 1.5 * iqr\n",
        "                    filtered_values = values[(values >= lower_bound) & (values <= upper_bound)]\n",
        "                    if filtered_values.size == 0: # All values were outliers, fall back to original\n",
        "                        filtered_values = values\n",
        "                else: # Not enough data, use all values\n",
        "                    filtered_values = values\n",
        "\n",
        "# Log the data being aggregated\n",
        "                # print(f\"DEBUG: Aggregating for step '{step_key}', metric '{base_name}': Data for mean/var = {filtered_values.tolist()}\")\n",
        "                agg_metrics[f\"{base_name}_mean\"] = jnp.mean(filtered_values)\n",
        "                agg_metrics[f\"{base_name}_var\"] = jnp.var(filtered_values)\n",
        "\n",
        "        aggregated_metrics[step_key] = agg_metrics\n",
        "\n",
        "    return aggregated_metrics\n",
        "\n",
        "\n",
        "def log_metrics(\n",
        "    aggregated_metrics: Dict[str, Dict[str, Any]], config: TrainingExperimentConfig\n",
        "):\n",
        "    \"\"\"Handle metric logging to appropriate destinations.\"\"\"\n",
        "    for step_key, metrics in aggregated_metrics.items():\n",
        "        figure = metrics.pop(\"figure\", None)\n",
        "        print(f\"Evaluation results for {step_key}:\")\n",
        "        for metric_name, value in metrics.items():\n",
        "            print(f\"{metric_name}: {value}\")\n",
        "        if figure is not None:\n",
        "            plt.show()\n",
        "            plt.close(figure)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 143,
      "metadata": {
        "id": "ggiCNa7QNBhz"
      },
      "outputs": [],
      "source": [
        "def _maybe_evaluate_and_save(\n",
        "    key: jax.random.PRNGKey,\n",
        "    epoch: int,\n",
        "    v_theta: PyTree,\n",
        "    config: TrainingExperimentConfig,\n",
        "    path_distribution: AnnealedDistribution,\n",
        "    target_density,\n",
        "    # validation_particles: Particle, # Removed\n",
        "    # validation_ts: Float[Array, \" time\"], # Removed\n",
        "    best_metrics: List[Tuple[float, int]],  # List of (metric_value, version)\n",
        "    model_version: int,\n",
        "    # Sharding specs (passed down):\n",
        "    param_replication: Optional[jshard.NamedSharding] = None,\n",
        ") -> Tuple[List[Tuple[float, int]], int]:  # Remove key from return tuple\n",
        "    \"\"\"\n",
        "    Performs model evaluation and saves the best model periodically. Handles sharded model.\n",
        "\n",
        "    Args:\n",
        "        key: Random key for this operation (will be split internally).\n",
        "        epoch: Current epoch number.\n",
        "        v_theta: Current velocity field model.\n",
        "        config: Training configuration.\n",
        "        path_distribution: Annealed distribution.\n",
        "        target_density: Target density (needed for saving context).\n",
        "        best_metrics: List of best metric values and versions so far.\n",
        "        model_version: Current model version counter.\n",
        "        param_replication: Sharding spec for model parameters (None if single device).\n",
        "\n",
        "    Returns:\n",
        "        A tuple containing:\n",
        "            - The updated list of best metrics.\n",
        "            - The updated model version counter.\n",
        "    \"\"\"\n",
        "    if epoch % config.training.eval_frequency == 0:\n",
        "        print(f\"--- Epoch {epoch}: Running Evaluation ---\")\n",
        "        # Run multiple evaluations (original code had loop for 1 iteration)\n",
        "        all_eval_results = []\n",
        "        num_eval_runs = 1  # Make configurable\n",
        "        # Split the key once for all evaluation runs\n",
        "        keys_eval = jax.random.split(key, num_eval_runs)\n",
        "        for i in range(num_eval_runs):\n",
        "            # key, subkey = jax.random.split(key) # Removed per-run split\n",
        "            print(f\"  Evaluation Run {i + 1}/{num_eval_runs}...\")\n",
        "            eval_metrics = evaluate_model(\n",
        "                keys_eval[i],  # Use the dedicated subkey for this run\n",
        "                v_theta,\n",
        "                config,\n",
        "                path_distribution,\n",
        "                target_density,\n",
        "                current_end_time=1.0,  # Pass num_timesteps from config\n",
        "            )\n",
        "            all_eval_results.append(eval_metrics)\n",
        "\n",
        "        # Process and log metrics\n",
        "        aggregated_metrics = aggregate_eval_metrics(all_eval_results)\n",
        "        log_metrics(aggregated_metrics, config)  # Pass epoch for logging\n",
        "\n",
        "        # Validation loss curve block removed (lines 588-628)\n",
        "\n",
        "        print(f\"--- Epoch {epoch}: Evaluation Complete ---\")\n",
        "\n",
        "    # Return updated metrics and version, key is not propagated\n",
        "    return best_metrics, model_version"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 144,
      "metadata": {
        "id": "ZOprEfMzNpEU"
      },
      "outputs": [],
      "source": [
        "@eqx.filter_jit\n",
        "def sample_monotonic_uniform_ordered(\n",
        "    key: jax.random.PRNGKey, bounds: Float[Array, \" time\"], include_endpoints: bool = True\n",
        ") -> Float[Array, \" time\"]:\n",
        "    def step(carry, info):\n",
        "        t_prev = carry\n",
        "        t_current = info\n",
        "\n",
        "        return t_current, jnp.array([t_prev, t_current])\n",
        "\n",
        "    _, ordered_pairs = jax.lax.scan(step, bounds[0], bounds[1:])\n",
        "\n",
        "    if include_endpoints:\n",
        "        ordered_pairs = jnp.concatenate(\n",
        "            [ordered_pairs, jnp.array([[1.0, 1.0]])], axis=0\n",
        "        )\n",
        "\n",
        "    samples = jax.random.uniform(\n",
        "        key, bounds.shape, minval=ordered_pairs[:, 0], maxval=ordered_pairs[:, 1]\n",
        "    )\n",
        "    return samples\n",
        "\n",
        "def sample_continuous_time(\n",
        "    key: jax.random.PRNGKey, base_ts: Float[Array, \" time\"]\n",
        ") -> Float[Array, \" time\"]:\n",
        "    \"\"\"\n",
        "    Samples new time steps monotonically based on existing base time steps.\n",
        "\n",
        "    Args:\n",
        "        key: JAX random key.\n",
        "        base_ts: The base time steps (sorted array, e.g., from setup_time_schedule).\n",
        "\n",
        "    Returns:\n",
        "        A new array of sampled time steps, ordered monotonically, matching the size of base_ts.\n",
        "    \"\"\"\n",
        "    # Assuming include_endpoints=True is the desired behavior based on core.py usage\n",
        "    return sample_monotonic_uniform_ordered(key, base_ts, include_endpoints=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 145,
      "metadata": {
        "id": "d-zvU45wNcrc"
      },
      "outputs": [],
      "source": [
        "def _run_training_loop(\n",
        "    key: jax.random.PRNGKey,\n",
        "    v_theta: PyTree,\n",
        "    opt_state: PyTree,\n",
        "    optimizer: optax.GradientTransformation,\n",
        "    lr_schedule_fn: Callable[[int], float],\n",
        "    path_distribution: AnnealedDistribution,\n",
        "    target_density,  # Needed for evaluation saving\n",
        "    base_ts: Float[Array, \" time\"],\n",
        "    # validation_particles: Particle, # Removed\n",
        "    # validation_ts: Float[Array, \" time\"], # Removed\n",
        "    log_Z_t_ref: List[Optional[Float[Array, \" time\"]]],\n",
        "    best_metrics: List[Tuple[float, int]],\n",
        "    model_version: int,\n",
        "    config: TrainingExperimentConfig,\n",
        "    # Sharding specs (passed down):\n",
        "    param_replication: Optional[jshard.NamedSharding] = None,\n",
        "    data_sharding_rank1: Optional[jshard.NamedSharding] = None,\n",
        "    data_sharding_rank2: Optional[jshard.NamedSharding] = None,\n",
        "    probe_sharding_rank2: Optional[jshard.NamedSharding] = None,\n",
        "    probe_sharding_rank3: Optional[jshard.NamedSharding] = None,\n",
        "    xs_sharding_spec: Optional[\n",
        "        jshard.NamedSharding\n",
        "    ] = None,  # Spec for particle dimension sharding in log_Z_t estimation\n",
        ") -> Tuple[PyTree, List[Tuple[float, int]]]:\n",
        "    \"\"\"Runs the main training loop over all epochs.\"\"\"\n",
        "\n",
        "    current_ts = base_ts  # Initialize current_ts\n",
        "    # Persisted MCMC samples - Initialized outside loop now\n",
        "    lambda_max = config.mcmc.lambda_max\n",
        "    lambda_epochs = config.mcmc.lambda_epochs\n",
        "    # Ensure lambda_total_steps is at least 1 to avoid division by zero\n",
        "    # Lambda factor calculation still based on outer steps for stability\n",
        "    lambda_total_steps = max(1, lambda_epochs * config.training.steps_per_epoch)\n",
        "\n",
        "    # Initialize the persisted MCMC sample store outside the loop\n",
        "    mcmc_samples_persisted: Optional[Dict[str, Any]] = None\n",
        "    optimizer_step_count = 0  # Initialize optimizer step counter\n",
        "\n",
        "    for epoch in range(config.training.num_epochs):\n",
        "        print(f\"\\n=== Starting Epoch {epoch}/{config.training.num_epochs - 1} ===\")\n",
        "        # Split the main key once per epoch\n",
        "        key, key_epoch = jax.random.split(key)  # Update key for next iteration\n",
        "        # Split the epoch key for the different random operations within the epoch\n",
        "        key_time, key_logz, key_samples, key_sampler, key_steps, key_eval = (\n",
        "            jax.random.split(key_epoch, 6)\n",
        "        )\n",
        "\n",
        "        # Use outer step count for lambda factor calculation\n",
        "        global_step_start_epoch = epoch * config.training.steps_per_epoch\n",
        "        epoch_optimizer_step_start = (\n",
        "            optimizer_step_count  # Store optimizer count at epoch start\n",
        "        )\n",
        "\n",
        "        # 1. Calculate Lambda Factor for the epoch (using outer step count)\n",
        "        current_lambda = _compute_lambda_factor(\n",
        "            global_step_start_epoch, lambda_total_steps, lambda_max\n",
        "        )\n",
        "        if not config.offline:\n",
        "            # Log lambda factor with outer step count for consistency with calculation\n",
        "            wandb.log(\n",
        "                {\n",
        "                    \"lambda_factor\": current_lambda,\n",
        "                    \"epoch\": epoch,\n",
        "                    \"global_step_outer\": global_step_start_epoch,\n",
        "                }\n",
        "            )\n",
        "        else:\n",
        "            print(\n",
        "                f\"Epoch {epoch}, Lambda Factor (Outer Step {global_step_start_epoch}): {current_lambda:.4f}\"\n",
        "            )\n",
        "\n",
        "        # 2. Update Time Steps if Continuous\n",
        "        # Note: Previously, ASMC handled time step updates within _maybe_estimate_log_z.\n",
        "        if (\n",
        "            config.integration.continuous_time\n",
        "        ):  # Removed redundant 'and config.mcmc.method != \"asmc\"'\n",
        "            # subkey_epoch, subkey_time = jax.random.split(subkey_epoch) # Removed redundant split\n",
        "            # Changed: Use time_utils.sample_continuous_time\n",
        "            current_ts = sample_continuous_time(\n",
        "                key_time, base_ts\n",
        "            )  # Use key_time\n",
        "            print(f\"Epoch {epoch}: Sampled new continuous time steps.\")\n",
        "\n",
        "        # 3. Estimate Log Z (if needed)\n",
        "        # Note: _maybe_estimate_log_z no longer returns a key\n",
        "        log_Z_t, new_mcmc_samples = _maybe_estimate_log_z(\n",
        "            key_logz,\n",
        "            epoch,\n",
        "            v_theta,\n",
        "            current_ts,\n",
        "            path_distribution,\n",
        "            config,\n",
        "            log_Z_t_ref,\n",
        "            current_lambda,  # Use key_logz\n",
        "            # --- Pass Sharding Specs ---\n",
        "            param_replication=param_replication,\n",
        "            xs_sharding_spec=xs_sharding_spec,  # Pass rank 3 spec for particle sharding\n",
        "            weights_sharding_spec=data_sharding_rank2,  # Pass rank 2 spec for weights sharding\n",
        "        )\n",
        "        # Update the persisted samples only if new ones were actually generated\n",
        "        if new_mcmc_samples is not None:\n",
        "            mcmc_samples_persisted = new_mcmc_samples\n",
        "\n",
        "        # 4. Prepare Samples for Epoch Steps\n",
        "        # Use the persisted samples (potentially from a previous epoch)\n",
        "        # Note: _prepare_epoch_samples no longer returns a key\n",
        "        epoch_training_samples = _prepare_epoch_samples(\n",
        "            key_samples,\n",
        "            v_theta,\n",
        "            current_ts,\n",
        "            path_distribution,\n",
        "            config,\n",
        "            mcmc_samples_persisted,\n",
        "            current_lambda,  # Use key_samples and persisted samples\n",
        "        )\n",
        "\n",
        "        # 5. Run Training Steps for Epoch\n",
        "        # Note: _run_steps_for_epoch now takes and returns optimizer_step_count\n",
        "        # sampler argument removed from the call below\n",
        "        v_theta, opt_state, avg_epoch_loss, optimizer_step_count = _run_steps_for_epoch(\n",
        "            key_steps,\n",
        "            v_theta,\n",
        "            opt_state,\n",
        "            optimizer,\n",
        "            lr_schedule_fn,\n",
        "            epoch,\n",
        "            initial_optimizer_step_count=epoch_optimizer_step_start,  # Pass count at epoch start\n",
        "            samples=epoch_training_samples,\n",
        "            current_ts=current_ts,\n",
        "            log_Z_t=log_Z_t,\n",
        "            path_distribution=path_distribution,\n",
        "            config=config,\n",
        "            # sampler argument removed\n",
        "            # Pass sharding specs down\n",
        "            param_replication=param_replication,\n",
        "            data_sharding_rank1=data_sharding_rank1,\n",
        "            data_sharding_rank2=data_sharding_rank2,\n",
        "            probe_sharding_rank2=probe_sharding_rank2,\n",
        "            probe_sharding_rank3=probe_sharding_rank3,\n",
        "        )\n",
        "\n",
        "        # 6. Calculate and Log Epoch Metrics (Excluding Validation Loss)\n",
        "        # Note: _calculate_and_log_epoch_metrics now takes the final optimizer step count\n",
        "        _calculate_and_log_epoch_metrics(\n",
        "            epoch=epoch,\n",
        "            avg_train_loss=avg_epoch_loss,\n",
        "            final_optimizer_step_count_for_epoch=optimizer_step_count,  # Pass final count\n",
        "            v_theta=v_theta,\n",
        "            path_distribution=path_distribution,\n",
        "            config=config,\n",
        "            lr_schedule_fn=lr_schedule_fn,\n",
        "        )\n",
        "\n",
        "        # 7. Evaluate and Save Model (Periodically)\n",
        "        # key, subkey_epoch = jax.random.split(key) # Removed redundant split\n",
        "        # Note: _maybe_evaluate_and_save no longer returns a key\n",
        "        best_metrics, model_version = _maybe_evaluate_and_save(\n",
        "            key_eval,\n",
        "            epoch,\n",
        "            v_theta,\n",
        "            config,\n",
        "            path_distribution,\n",
        "            target_density,  # Use key_eval\n",
        "            # validation_particles, validation_ts, # Removed\n",
        "            best_metrics,\n",
        "            model_version,\n",
        "            # Pass sharding specs down\n",
        "            param_replication=param_replication,\n",
        "        )\n",
        "\n",
        "        print(f\"=== Finished Epoch {epoch} ===\")\n",
        "\n",
        "    return v_theta, best_metrics  # Return final model and best metrics list"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 146,
      "metadata": {
        "id": "kBmorAHwNzMd"
      },
      "outputs": [],
      "source": [
        "def train_velocity_field(\n",
        "    key: jax.random.PRNGKey,\n",
        "    initial_density,\n",
        "    target_density,\n",
        "    v_theta: PyTree,  # Explicitly PyTree (Equinox module)\n",
        "    config: TrainingExperimentConfig,\n",
        ") -> PyTree:\n",
        "    \"\"\"\n",
        "    Trains a velocity field model (v_theta).\n",
        "\n",
        "    Args:\n",
        "        key: JAX random key.\n",
        "        initial_density: The initial distribution P_0.\n",
        "        target_density: The target distribution P_1.\n",
        "        v_theta: The velocity field model (an Equinox module).\n",
        "        config: The training configuration.\n",
        "\n",
        "    Returns:\n",
        "        The trained velocity field model.\n",
        "    \"\"\"\n",
        "    print(\"--- Starting Velocity Field Training ---\")\n",
        "    print(f\"Config: {config}\")  # Log config at start\n",
        "\n",
        "    # LBFGS + RAD warning removed as LBFGS is handled in lbfgs_core.py\n",
        "\n",
        "    # 1. Initialization\n",
        "    # Removed subkey_init_opt, subkey_init_ts, subkey_init_path from split as they are no longer used here explicitly\n",
        "    # The main loop consumes one key (`subkey_loop`)\n",
        "    key, subkey_loop = jax.random.split(key)  # Split key for the main loop\n",
        "\n",
        "    # --- Device Detection and Sharding Setup ---\n",
        "    devices = jax.devices()\n",
        "    num_devices = len(devices)\n",
        "    print(f\"Detected {num_devices} devices.\")\n",
        "\n",
        "    # Initialize all sharding specs to None\n",
        "    param_replication = None\n",
        "    data_sharding_rank1 = None  # For (batch,)\n",
        "    data_sharding_rank2 = None  # For (batch, dim)\n",
        "    probe_sharding_rank2 = None  # For (batch, dim) probes\n",
        "    probe_sharding_rank3 = None  # For (batch, n_probes, dim) probes\n",
        "    xs_sharding_spec = (\n",
        "        None  # For (batch, n_probes, dim) particles in the second dimension\n",
        "    )\n",
        "\n",
        "    if num_devices > 1:\n",
        "        print(\"Setting up multi-device execution...\")\n",
        "        # Check batch size divisibility (using config.training.batch_size)\n",
        "        if config.training.batch_size % num_devices != 0:\n",
        "            raise ValueError(\n",
        "                f\"Training batch size ({config.training.batch_size}) must be \"\n",
        "                f\"divisible by the number of devices ({num_devices}) for data parallelism.\"\n",
        "            )\n",
        "\n",
        "        # Create device mesh (1D mesh for data parallelism)\n",
        "        device_mesh = mesh_utils.create_device_mesh((num_devices,))\n",
        "        mesh = jshard.Mesh(device_mesh, axis_names=(\"data\",))  # Create the Mesh object\n",
        "        print(f\"Using device mesh: {device_mesh} with axis names ('data',)\")\n",
        "\n",
        "        # --- Define Sharding Specifications ---\n",
        "        # Parameters: Replicated across all devices\n",
        "        param_replication = jshard.NamedSharding(\n",
        "            mesh, jshard.PartitionSpec()\n",
        "        )  # Replicated\n",
        "\n",
        "        # Rank 1 Data: Shard batch dim ('data')\n",
        "        data_sharding_rank1 = jshard.NamedSharding(mesh, jshard.PartitionSpec(\"data\"))\n",
        "\n",
        "        # Rank 2 Data/Probes: Shard batch dim ('data'), replicate feature dim\n",
        "        data_sharding_rank2 = jshard.NamedSharding(\n",
        "            mesh, jshard.PartitionSpec(\"data\", None)\n",
        "        )\n",
        "        probe_sharding_rank2 = data_sharding_rank2  # Same sharding for rank 2 probes\n",
        "\n",
        "        # Rank 3 Probes: Shard batch dim ('data'), replicate probe_count and feature dims\n",
        "        probe_sharding_rank3 = jshard.NamedSharding(\n",
        "            mesh, jshard.PartitionSpec(\"data\", None, None)\n",
        "        )\n",
        "\n",
        "        # Rank 3 Particles: Shard batch dim ('data'), replicate particle_count and feature dims\n",
        "        xs_sharding_spec = jshard.NamedSharding(\n",
        "            mesh, jshard.PartitionSpec(None, \"data\", None)\n",
        "        )\n",
        "\n",
        "        print(\"Sharding specifications created using NamedSharding.\")\n",
        "    else:\n",
        "        print(\"Running on a single device. Sharding specifications remain None.\")\n",
        "        # All sharding specs remain None\n",
        "    # --- End Sharding Setup ---\n",
        "\n",
        "    best_metrics: List[Tuple[float, int]] = []\n",
        "    model_version = 0\n",
        "    log_Z_t_ref: List[Optional[Float[Array, \" time\"]]] = [\n",
        "        None\n",
        "    ]  # Mutable reference for log Z\n",
        "\n",
        "    optimizer, lr_schedule_fn = _setup_optimizer(config)\n",
        "    # Changed: Use time_utils.setup_time_schedule\n",
        "    base_ts = jnp.linspace(0, 1.0, config.sampling.num_timesteps)\n",
        "    path_distribution = _setup_path_distribution(\n",
        "        initial_density, target_density, config\n",
        "    )\n",
        "    opt_state = optimizer.init(eqx.filter(v_theta, eqx.is_inexact_array))\n",
        "\n",
        "    print(\"Initialized Optimizer, Time Steps, Path Distribution.\")\n",
        "    print(\"Validation loss curve generation is disabled.\")\n",
        "\n",
        "    # 2. Run Training Loop\n",
        "    v_theta, best_metrics = _run_training_loop(\n",
        "        key=subkey_loop,  # Use the dedicated key for the loop\n",
        "        v_theta=v_theta,\n",
        "        opt_state=opt_state,\n",
        "        optimizer=optimizer,\n",
        "        lr_schedule_fn=lr_schedule_fn,\n",
        "        path_distribution=path_distribution,\n",
        "        target_density=target_density,\n",
        "        base_ts=base_ts,\n",
        "        # validation_particles=validation_particles, # Removed\n",
        "        # validation_ts=validation_ts_init, # Removed\n",
        "        log_Z_t_ref=log_Z_t_ref,\n",
        "        best_metrics=best_metrics,\n",
        "        model_version=model_version,  # Pass initial version\n",
        "        config=config,\n",
        "        # Pass sharding specs\n",
        "        param_replication=param_replication,\n",
        "        data_sharding_rank1=data_sharding_rank1,\n",
        "        data_sharding_rank2=data_sharding_rank2,\n",
        "        probe_sharding_rank2=probe_sharding_rank2,\n",
        "        probe_sharding_rank3=probe_sharding_rank3,\n",
        "        xs_sharding_spec=xs_sharding_spec,\n",
        "    )\n",
        "\n",
        "    print(\"--- Finished Velocity Field Training ---\")\n",
        "    return v_theta\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fh_eUOrtOPDK"
      },
      "source": [
        "# Train and Eval"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 147,
      "metadata": {
        "id": "vjr2E6mGOL4H"
      },
      "outputs": [],
      "source": [
        "key = jax.random.PRNGKey(888)\n",
        "\n",
        "# Create configuration objects\n",
        "sampling_config = SamplingConfig(\n",
        "    num_particles=2560,\n",
        "    num_timesteps=128,\n",
        ")\n",
        "\n",
        "training_config = TrainingConfig(\n",
        "    num_epochs=2000,\n",
        "    steps_per_epoch=500,\n",
        "    learning_rate=1e-04,\n",
        "    gradient_clip_norm=1.0,\n",
        "    eval_frequency=5,\n",
        "    optimizer=\"adamw\",\n",
        "    weight_decay=1e-06,\n",
        "    beta1=0.9,\n",
        "    beta2=0.999,\n",
        "    time_batch_size=64,\n",
        "    use_shortcut=True,\n",
        "    shortcut_size=[32, 64, 128],\n",
        "    use_combined_loss=True,\n",
        "    shortcut_weight=1.0,\n",
        "    random_alpha=True,\n",
        "    log_z_estimation_frequency=1,\n",
        "    training_data=\"combined\",\n",
        "    batch_size=128,\n",
        ")\n",
        "\n",
        "mcmc_config = MCMCConfig(\n",
        "    method=\"vsmc\",\n",
        "    num_steps=4,\n",
        "    num_integration_steps=5,\n",
        "    step_size=0.1,\n",
        "    use_control_variate=True,\n",
        "    lambda_max=0.1,\n",
        "    lambda_epochs=200,\n",
        "    ess_threshold=0.5,\n",
        ")\n",
        "\n",
        "integration_config = IntegrationConfig(\n",
        "    method=\"euler\",\n",
        "    schedule=\"linear\",\n",
        ")\n",
        "\n",
        "model_config = ModelConfig(\n",
        "    hidden_dim=256,\n",
        "    num_layers=4,\n",
        "    architecture=\"mlp2\",\n",
        "    mlp_depth=4,\n",
        ")\n",
        "\n",
        "density_config = DensityConfig(\n",
        "    target_type=\"gmm\",\n",
        "    initial_sigma=25.0,\n",
        "    annealing_path=\"linear\",\n",
        "    input_dim=2,\n",
        "    n_samples_eval=2000,\n",
        ")\n",
        "\n",
        "config = TrainingExperimentConfig(\n",
        "    sampling=sampling_config,\n",
        "    training=training_config,\n",
        "    mcmc=mcmc_config,\n",
        "    integration=integration_config,\n",
        "    model=model_config,\n",
        "    density=density_config,\n",
        "    offline=True,\n",
        "    mixed_precision=False,\n",
        ")\n",
        "\n",
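        "# Full-precision policy: config.mixed_precision is False, so parameters,\n",
        "# compute, and outputs all stay in float32.\n",
        "mp_policy = jmp.Policy(\n",
        "    param_dtype=jnp.float32,\n",
        "    compute_dtype=jnp.float32,\n",
        "    output_dtype=jnp.float32,\n",
        ")\n",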
        "config.mp_policy = mp_policy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 148,
      "metadata": {
        "id": "84mr1Pf_P2aS"
      },
      "outputs": [],
      "source": [
        "initial_density = MultivariateGaussian(\n",
        "    mean=jnp.zeros(config.density.input_dim),\n",
        "    dim=config.density.input_dim,\n",
        "    sigma=config.density.initial_sigma,\n",
        ")\n",
        "\n",
        "key, subkey = jax.random.split(key)\n",
        "target_density = GMM(\n",
        "    subkey,\n",
        "    dim=config.density.input_dim,\n",
        "    n_samples_eval=config.density.n_samples_eval,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 149,
      "metadata": {
        "id": "Jgaw_LtTQLHA"
      },
      "outputs": [],
      "source": [
        "key, model_key = jax.random.split(key)\n",
        "v_theta = VelocityFieldTwo(\n",
        "    model_key,\n",
        "    dim=config.density.input_dim,\n",
        "    hidden_dim=config.model.hidden_dim,\n",
        "    depth=config.model.num_layers,\n",
        "    shortcut=config.training.use_shortcut,\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "8vcQX9KTQRH7",
        "outputId": "dd84409b-0ce5-453f-ca7e-8edeef51f7b3"
      },
      "outputs": [],
      "source": [
        "v_theta = train_velocity_field(\n",
        "    key=key,\n",
        "    initial_density=initial_density,\n",
        "    target_density=target_density,\n",
        "    v_theta=v_theta,\n",
        "    config=config,\n",
        ")"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
