import jax
import jax.numpy as jnp

def main():
    """Demonstrate a masked ``jnp.max`` in which every element is masked out.

    Builds the vector ``[1, ..., 8]`` together with an all-``False`` boolean
    mask of the same shape, reduces with ``jnp.max(..., where=mask,
    initial=0)``, prints the result, and returns it.

    Returns:
        The scalar JAX array produced by the reduction. Since no element is
        selected by the mask, this equals ``initial`` (``0``).
    """
    # Input values: a vector of shape (8,) holding 1..8.
    values = jnp.array([1, 2, 3, 4, 5, 6, 7, 8])

    # `where` is documented as a boolean array, so build the mask as bool
    # directly instead of relying on a float array being coerced to bool.
    # All entries are False: every element is excluded from the reduction.
    mask = jnp.zeros(values.shape, dtype=bool)

    # `initial` is required when `where` is given; with nothing selected it
    # is also the value the reduction falls back to.
    result = jnp.max(values, where=mask, initial=0)

    print(result)
    return result


if __name__ == "__main__":
    main()
