base_imports = """
import jax
import jax.numpy as jnp
"""
