import jax

# Module-wide switch read by maybe_jit(): when false, maybe_jit() returns
# functions unwrapped. Can be changed at runtime through set_jit().
JIT = True
# When true, maybe_jit() passes functions through enzyme_fn before jax.jit.
ENZYME = False

if ENZYME:
    # Imported only when the flag is set, so enzyme_ad is not needed otherwise.
    import enzyme_ad.jax as ejax
    # Pipeline option string: a fixed "inline{...}," prefix followed by
    # whatever ejax.primitives.hlo_opts() returns.
    ENZYME_OPTS = "inline{default-pipeline=canonicalize max-iterations=4}," + ejax.primitives.hlo_opts()
    # Wrapper applied to functions inside maybe_jit().
    # NOTE(review): presumably routes the function through Enzyme's JAX IR
    # pipeline with the options above -- confirm against the enzyme_ad docs.
    enzyme_fn = ejax.enzyme_jax_ir(pipeline_options=ejax.JaXPipeline(ENZYME_OPTS))

def maybe_jit(fn, off=False, **kwargs):
    """Return ``fn`` wrapped in ``jax.jit``, or ``fn`` itself when JIT is off.

    The decision is taken once, when this function is called, from the
    module-level ``JIT`` flag and the per-call ``off`` override.

    Args:
        fn: Callable to wrap.
        off: When true, skip wrapping for this call regardless of ``JIT``.
        **kwargs: Forwarded to ``jax.jit``; unused when ``fn`` is returned
            unwrapped.

    Returns:
        ``jax.jit(fn, **kwargs)`` -- with ``fn`` first passed through
        ``enzyme_fn`` when ``ENZYME`` is set -- or the original ``fn``.
    """
    # Guard clause: wrapping is disabled globally or for this call.
    if not JIT or off:
        return fn

    target = enzyme_fn(fn) if ENZYME else fn
    return jax.jit(target, **kwargs)

def set_jit(jit):
    """Set the module-level ``JIT`` flag to ``jit``.

    Only later ``maybe_jit`` calls observe the new value; functions that
    ``maybe_jit`` has already returned are not affected.
    """
    # Write straight into this module's namespace (same effect as `global`).
    globals()["JIT"] = jit