import jax

@jax.jit
def f(x):
    """Return ``x`` unchanged, announcing each JAX trace on stdout.

    Because the function is wrapped by ``@jax.jit``, its Python body runs
    only while JAX traces it. That happens on the first call and again for
    each new input shape/dtype. Calls that hit the compilation cache skip
    the ``print`` entirely, so the message marks a (re)trace, not every call.
    """
    message = 'Compiling f'
    print(message)
    passthrough = x
    return passthrough

def double_f(x, y):
    """Apply the module-level jitted ``f`` to ``x`` and then to ``y``.

    ``f`` is already decorated with ``@jax.jit``, so it is called directly.
    Re-wrapping it with ``jax.jit(f)`` on every call would be redundant
    (jit of a jitted function) and would build a throwaway wrapper per
    invocation.

    ``f`` is an identity function whose only side effect is a trace-time
    ``print``. Both results are discarded, and this function returns
    ``None``. Expect at most one "Compiling f" line per distinct input
    shape/dtype.

    Args:
        x: First input passed to ``f``.
        y: Second input passed to ``f``.
    """
    f(x)
    f(y)