
import jax
import jax.numpy as jnp
from jax import lax

# 示例变量，实际应用中你的 traced_val1 和 traced_val2 是通过计算得到的
traced_val1 = jax.numpy.array(1)
traced_val2 = jax.numpy.array(1)

@jax.jit
def compare_values(val1, val2):
    # 使用 jnp.equal 进行数值比较
    comparison = jnp.equal(val1, val2)
    # 使用 lax.cond 处理布尔转换
    return lax.cond(comparison, lambda _: True, lambda _: False, operand=None)

# 调用 JIT 编译的函数进行检查
result = compare_values(traced_val1, traced_val2)

# 将结果从设备上同步到主机
result = jax.device_get(result)

assert result, "The values are not equal"

print("The values are equal")