# Python package dependencies; every entry below is pinned to a single exact
# version with `==`.
#
# NOTE(review): no `jaxlib` or `torch` pin appears in the lines visible here,
# although `jax` and `torchvision` are pinned. Presumably the installer picks
# those versions on its own -- confirm whether they should be pinned too so
# that installs are reproducible.
flax==0.6.4
jax==0.4.1
matplotlib==3.6.3
numpy==1.24.3
optax==0.1.4
Pillow==10.0.1
# NOTE(review): spelled with an underscore here; presumably resolves to the
# same project as "scikit-learn" through package-name normalization -- verify
# with the installer in use.
scikit_learn==1.2.1
scipy==1.11.3
torchvision==0.15.2
typer==0.9.0
