# Conda environment specification for the `osds` environment.
# Typical usage: `conda env create -f <this file>`.
name: osds
channels:
  # The only channel listed in this file.
  - conda-forge
dependencies:
  # Interpreter pin and a minimum pip version for the nested pip section.
  - python=3.10
  - pip>=23.3
  # Build tooling, intentionally left unpinned here.
  # NOTE(review): presumably needed to build some pip packages from source -- confirm.
  - cmake
  - ninja
  - setuptools
  - wheel
  # Packages installed through pip rather than conda.
  # Every requirement below is pinned to an exact version with `==`.
  - pip:
      # pip option line: an additional page pip searches for distributions.
      - --find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
      # jaxlib carries a local version tag (+cuda12.cudnn89).
      # NOTE(review): presumably resolved via the find-links page above -- confirm.
      - jaxlib==0.4.28+cuda12.cudnn89
      # jax is pinned to the same release number as jaxlib.
      - jax==0.4.28
      # NOTE(review): 8.9.x looks chosen to match the `cudnn89` tag on jaxlib -- confirm.
      - nvidia-cudnn-cu12==8.9.7.29
      # JAX-ecosystem libraries.
      - numpyro==0.15.0
      - flax==0.8.3
      - optax==0.2.4
      - dm-haiku==0.0.12
      - distrax==0.1.5
      - blackjax==1.0.0
      - ott-jax==0.4.6
      # Hydra configuration framework and its launcher plugins.
      - hydra-core==1.3.2
      - hydra-joblib-launcher==1.2.0
      - hydra-submitit-launcher==1.2.0
      # Experiment tracking, config helpers, progress bars, plotting, data frames.
      - wandb==0.15.10
      - ml-collections==0.1.1
      - tqdm==4.66.1
      - matplotlib==3.8.4
      - pandas==2.2.2
      - inference-gym==0.0.4
      # TensorFlow stack.
      # NOTE(review): tf-keras and tensorflow-probability pins look chosen to pair
      # with tensorflow 2.16 -- confirm before bumping any one of them alone.
      - tensorflow==2.16.1
      - tf-keras==2.16.0
      - tensorflow-probability==0.24.0
      # Checkpointing and utility libraries.
      - orbax-checkpoint==0.5.4
      - etils==1.7.0