import jax

# Process-wide JAX setting, applied as a side effect of importing this module.
# It sets the global "jax_default_matmul_precision" option to "highest" for
# every JAX computation traced after this point.
# NOTE(review): presumably this avoids reduced-precision (e.g. bf16/tf32)
# matmuls on accelerators for numerical accuracy. Confirm the intent, and
# whether a scoped `jax.default_matmul_precision(...)` context would be
# preferable to a global default.
jax.config.update(
    "jax_default_matmul_precision",
    "highest",
)
