---
# Conda environment definition.
# Create with: conda env create -f <this file>
name: swg_jax

# Channels are listed in priority order: conda prefers packages from
# earlier entries over later ones.
channels:
  - defaults
  - conda-forge

# Packages installed by conda. Conda match-spec operators used below:
#   `=`  fuzzy match (e.g. `=3.9` accepts any 3.9.x)
#   `==` exact match
dependencies:
  # Fuzzy pin: `python==3.9` would be an exact match that resolves only to
  # 3.9.0 and blocks later 3.9.x patch releases.
  - python=3.9
  # NOTE(review): pip is pinned exactly; presumably to keep the older gym
  # releases requested below installable -- confirm before bumping.
  - pip==24.0
  - numpy>=1.20.2,<=1.26.4
  - scipy>=1.4.1
  - tqdm=4.64.0
  - imageio=2.19.1
  - imageio-ffmpeg=0.4.7
  - cython==0.29.30
  - pillow==8.4.0
  - hydra-core==1.3.2
  - wandb
  - matplotlib
  - conda-build
  # Packages installed by pip after the conda packages above are in place.
  - pip:
    - gym>=0.21.0,<0.24.0
    # Quoted because `[` would otherwise be read as a YAML flow sequence
    # indicator by some tools.
    - "jax[cuda12]"
    - optax==0.1.5
    - flax
    # NOTE(review): tracks the moving `master` branch, so installs are not
    # reproducible; consider pinning to a specific commit.
    - d4rl @ git+https://github.com/Farama-Foundation/d4rl@master#egg=d4rl
    - mujoco<=3.1.6
    # Version cap: see https://github.com/Farama-Foundation/D4RL/issues/236
    - dm_control<=1.0.20