---
# Conda environment for JaxCQL.
# Create with: conda env create -f <this file>
name: JaxCQL

# Channels are searched in the order listed.
channels:
  - defaults
  - nvidia
  - conda-forge

dependencies:
  # Interpreter and pip (pip is needed for the `pip:` sub-list below).
  - python=3.8
  - pip

  # Scientific / data stack (unpinned).
  - numpy
  - scipy
  - h5py
  - matplotlib
  - scikit-learn
  - jupyter
  - tqdm
  - seaborn
  - Cython

  # JAX. The third `=` field of a conda match spec is a build-string
  # pattern; `*cuda*` restricts jaxlib to builds whose build string
  # contains "cuda".
  - jax=0.3.16
  - jaxlib=0.3.15=*cuda*

  # CUDA toolchain; toolkit and nvcc are both pinned to the 11.3 series.
  # cudnn is left unpinned for the solver to choose.
  - cudatoolkit=11.3
  - cuda-nvcc=11.3
  - cudnn

  # Packages installed with pip after the conda packages above.
  - pip:
      - flax==0.6.0
      - optax==0.1.3
      - distrax==0.1.2
      - gym
      - absl-py
      # Installed from the default branch of the Git repository
      # (no commit pinned).
      - git+https://github.com/Farama-Foundation/D4RL.git
      - wandb
      - ml_collections