jax==0.4.26
jaxlib==0.4.26
matplotlib==3.9.0
chex==0.1.86
flax==0.8.4
wandb==0.17.3
tqdm==4.66.4
optax==0.2.2
hydra-core==1.3.2
omegaconf==2.3.0
scikit-learn==1.5.1
seaborn==0.13.2
huggingface_hub==0.24.6
networkx==3.3
torch==2.4.1
einshape==1.0