# Python package requirements (pip requirements-file format).
# "==" pins an exact release; ">=" sets a minimum and allows newer releases.
# NOTE(review): jax is the only entry with no version specifier, so the
# installed jax version can drift between installs (it is left to the
# resolver and to other packages' constraints) — consider pinning it for
# reproducible environments.
jax
# NOTE(review): presumably MuJoCo Playground (published on PyPI as
# "playground") — confirm this is the intended distribution.
playground==0.0.5
matplotlib==3.9.2
jax_dataclasses>=1.6.3
wandb==0.21.0
tyro>=1.0.0
