import jax
import jax.numpy as jnp

def demo_jit_closure():
    """Print and return the result of a jitted function that closes over ``a``.

    ``jax.jit`` traces ``f`` on its first call and reads the captured ``a`` at
    that moment, baking its value into the compiled computation as a constant.
    Because ``a`` is rebound to 20 before the first call, the trace sees 20 and
    the result is 20 * 2 = 40. Rebinding ``a`` *after* the first call would not
    change later results for same-shaped inputs, because JAX reuses the cached
    trace (see the JAX "Sharp Bits" notes on pure functions).

    Returns:
        The JAX array produced by ``f(2)`` (value 40).
    """
    # define a variable
    a = 10

    # define a jitted function that closes over that variable
    @jax.jit
    def f(x):
        return a * x

    # update the value of that variable before the first (tracing) call
    a = 20

    result = f(2)
    print(result)
    return result


def demo_prng_split():
    """Print how ``jax.random.split`` reshapes legacy ``PRNGKey`` keys.

    Walks through splitting a single key, splitting a batch of keys with
    ``jax.vmap``, and splitting a 4x4 grid of keys with a nested ``vmap``.
    Shapes below assume the default threefry key implementation, where a raw
    key is a uint32 array of shape (2,).

    Returns:
        A tuple ``(batched, split_all, second_only)``:
        ``batched`` has shape (2, 2, 2): each of two keys split in two;
        ``split_all`` has shape (4, 4, 2, 2): every grid key split in two;
        ``second_only`` has shape (4, 4, 2): the second key of each split,
        i.e. equal to ``split_all[:, :, 1]``.
    """
    random_key = jax.random.PRNGKey(0)

    print("1. ", random_key.shape)
    print("1. ", random_key)

    # Unpacking the default split (num=2) yields two independent keys.
    random_key, subkey = jax.random.split(random_key)

    print("2. ", random_key.shape)
    print("2. ", random_key)
    print("2. ", subkey.shape)
    print("2. ", subkey)

    # Without unpacking, both new keys stay stacked in one (2, 2) array.
    random_key = jax.random.split(random_key)

    print("3. ", random_key.shape)
    print("3. ", random_key)

    # `random_key` is now a batch of two keys, but `jax.random.split` accepts
    # only a single key (current JAX raises TypeError for a batch), so map it
    # over the leading axis: (2, 2) -> (2, 2, 2).
    random_key = jax.vmap(jax.random.split)(random_key)

    print("4. ", random_key.shape)
    print("4. ", random_key)

    # 16 keys derived from `subkey`, arranged as a 4x4 grid of keys. Reshaping
    # with `subkey.shape` (rather than -1) keeps the per-key trailing
    # dimensions intact for both raw and typed keys.
    subkeys = jax.random.split(subkey, 16).reshape((4, 4, *subkey.shape))

    print("5. ", subkeys.shape)
    print("5. ", subkeys)

    # Nested vmap maps over both grid axes, splitting every key in the grid.
    new_subkeys1 = jax.vmap(jax.vmap(jax.random.split))(subkeys)
    # Same per-key split, but keep only the second of the two new keys.
    new_subkeys2 = jax.vmap(jax.vmap(lambda x: jax.random.split(x)[1]))(subkeys)

    print("6. ", new_subkeys1.shape)
    print("6. ", new_subkeys1)

    print("7. ", new_subkeys2.shape)
    print("7. ", new_subkeys2)

    return random_key, new_subkeys1, new_subkeys2


def main():
    """Run the jit-closure demo, then the PRNG-splitting demo."""
    demo_jit_closure()

    # ---------------

    demo_prng_split()


if __name__ == "__main__":
    main()
