from lapjax.interpreters import ad as ad
from lapjax.interpreters import pxla as pxla
from lapjax.interpreters import partial_eval as partial_eval
from lapjax.interpreters import xla as xla
from lapjax.interpreters import batching as batching
from lapjax.interpreters import mlir as mlir
import sys, importlib
from lapjax.lapsrc.wrapper import _wrap_module
# Pair this module with its upstream counterpart. The counterpart's name is
# derived by replacing 'lapjax' with 'jax' in this module's name.
# Both the imported upstream module and this module's own object (looked up in
# sys.modules) are handed to _wrap_module. What _wrap_module does with them
# lives in lapjax.lapsrc.wrapper, not here.
# NOTE(review): presumably it exposes the jax attributes on this module;
# confirm against the wrapper.
_wrap_module(
    importlib.import_module(__name__.replace('lapjax', 'jax')),
    sys.modules[__name__],
)
