from lapjax.experimental.jax2tf import call_tf as call_tf
from lapjax.experimental.jax2tf import impl_no_xla as impl_no_xla
from lapjax.experimental.jax2tf import shape_poly as shape_poly
from lapjax.experimental.jax2tf import jax2tf as jax2tf
import sys, importlib
from lapjax.lapsrc.wrapper import _wrap_module
# Pair this module with its upstream JAX counterpart and hand both to
# `_wrap_module`. The counterpart's dotted name is this module's `__name__`
# with every "lapjax" substring replaced by "jax"
# (e.g. "lapjax.experimental.jax2tf" -> "jax.experimental.jax2tf").
# The first argument is the imported JAX module; the second is this module
# object, looked up in `sys.modules`.
# NOTE(review): presumably `_wrap_module` copies/wraps the JAX module's public
# attributes onto this module -- confirm in lapjax.lapsrc.wrapper.
# Everything is passed inline, so no helper names leak into this module's
# namespace, which `_wrap_module` may inspect.
_wrap_module(
    importlib.import_module(__name__.replace('lapjax', 'jax')),
    sys.modules[__name__],
)
