import os
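# Optional GPU / XLA environment overrides, disabled by default. Uncomment as
# needed; they are placed above `import jax` so that they are already set
# when JAX is imported.
# os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")
# os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
# os.environ.setdefault("XLA_PYTHON_CLIENT_MEM_FRACTION", "0.85")
# os.environ.setdefault("XLA_FLAGS", "--xla_gpu_enable_command_buffer=")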

import jax
from trainer import make_train
from config import default_cfg

if __name__ == "__main__":
    # Script entry point: build the default configuration, seed a PRNG key
    # with 0, construct the training function and run it once.
    config = default_cfg()
    seed_key = jax.random.PRNGKey(0)
    # `make_train` returns a callable that takes the PRNG key; its result is
    # bound to `out` so it stays available at module level after the run.
    out = make_train(config)(seed_key)