import jax
import jax.numpy as jnp
import numpy as np

# NOTE: reference notes from an earlier experiment on computing the running
# (cumulative) maximum of a 1-D array:
#
#   x = jnp.array([1.1, 1.2, 1., 1.1, 1.3, 1.2, 1.4, 1.5, 1.6, 1.2])
#   jax.lax.cummax(x, axis=0)
#   # -> [1.1, 1.2, 1.2, 1.2, 1.3, 1.3, 1.4, 1.5, 1.6, 1.6]
#
# The same result can be built from a *lower*-triangular mask (jnp.tril), in
# which row i selects x[0..i]. Masked-out entries must be filled with -inf,
# not multiplied by 0: zero-filling silently yields wrong maxima as soon as x
# contains negative values. The mask also costs O(n^2) memory, so prefer
# jax.lax.cummax.
#
#   mask = jnp.tril(jnp.ones((x.shape[0], x.shape[0]), dtype=bool))
#   jnp.max(jnp.where(mask, x, -jnp.inf), axis=1)


def main(rows: int = 3, cols: int = 8) -> jnp.ndarray:
    """Print and return a (rows, cols) array of ones divided by [cols, ..., 1].

    Demonstrates broadcasting: the 1-D divisor of shape (cols,) is stretched
    across the leading axis of the (rows, cols) numerator, so column j of
    every row holds 1 / (cols - j).

    Args:
        rows: Number of rows in the numerator (default 3, as in the original
            script).
        cols: Number of columns in the numerator and length of the divisor
            (default 8, as in the original script).

    Returns:
        The (rows, cols) quotient array that was printed.
    """
    numerator = jnp.ones((rows, cols))
    # Descending [cols, cols - 1, ..., 1]; it never contains 0, so no inf/nan.
    divisor = jnp.arange(cols, 0, -1)
    quotient = numerator / divisor
    print(quotient)
    return quotient


if __name__ == "__main__":
    main()
