import jax
def compute_pytree_size(pytree, *, verbose=True):
    """Return the total memory footprint of a pytree's leaves, in bytes.

    Each leaf contributes ``leaf.size * leaf.dtype.itemsize`` bytes, so
    every leaf must expose ``.size`` and ``.dtype`` (e.g. JAX or NumPy
    arrays). Plain Python scalars have no ``.size`` and raise
    ``AttributeError``. ``None`` entries are not leaves for
    ``jax.tree_util.tree_leaves`` and therefore contribute nothing.

    Args:
        pytree: Any JAX pytree (dict/list/tuple/... nesting) of arrays.
        verbose: When true (the default, matching the original behavior),
            print the total as "megabytes" computed with a 1024**2 divisor
            (i.e. MiB), to three decimals.

    Returns:
        The total size of all leaves, in bytes (0 for an empty pytree).
    """
    total_size = sum(
        leaf.size * leaf.dtype.itemsize
        for leaf in jax.tree_util.tree_leaves(pytree)
    )
    if verbose:
        # 1024 ** 2 divisor: the reported figure is binary megabytes (MiB).
        in_mb = total_size / (1024 ** 2)
        print(f"model size in megabytes: {in_mb:.3f}")
    return total_size


