FROM nvcr.io/nvidia/pytorch:22.05-py3

# Conda dependencies: update conda and install mamba first so the remaining
# solves are faster, then install graph-tool and rdkit from conda-forge.
# Every install passes -y because `docker build` is non-interactive and cannot
# answer a confirmation prompt. `conda clean -afy` runs in the same layer so
# downloaded package tarballs are not baked into the image.
RUN conda update -n base -y conda \
 && conda install -n base -c conda-forge -y mamba \
 && conda clean -afy
RUN mamba install -n base -c conda-forge -y graph-tool \
 && conda clean -afy
RUN mamba install -n base -c conda-forge -y rdkit \
 && conda clean -afy

# Python dependencies.
RUN pip install --no-cache-dir pyyaml \
 && pip install --no-cache-dir overrides imageio numpy scipy tqdm wandb hydra-core seaborn
# torch/torchvision are pinned to the matching CUDA 11.3 builds (1.11.0 / 0.12.0)
# so they agree with the PyG wheel index below (torch-1.11.0+cu113); leaving
# torchvision unpinned forces pip to backtrack through releases built for other
# torch versions.
# NOTE(review): pytorch_lightning and torchmetrics are unpinned, so rebuilds may
# pull different releases — consider pinning versions known to work with torch 1.11.
RUN pip install --no-cache-dir \
        pytorch_lightning \
        torchmetrics \
        torch==1.11.0+cu113 \
        torchvision==0.12.0+cu113 \
        --extra-index-url https://download.pytorch.org/whl/cu113
RUN pip install --no-cache-dir \
        torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric \
        -f https://data.pyg.org/whl/torch-1.11.0+cu113.html
# Application data and code. WORKDIR is set explicitly so `pip install -e .` and
# the CMD below do not depend on the base image's default working directory.
WORKDIR /workspace
# The QM9 dataset changes least often, so it is copied first to keep its layer
# cached across code edits. COPY (not ADD) because these are plain local files.
COPY ./data/qm9/qm9_pyg /workspace/data/qm9/qm9_pyg
COPY dgd/analysis /workspace/analysis
COPY dgd/configs /workspace/configs
COPY dgd/diffusion /workspace/diffusion
COPY dgd/ggg_data /workspace/ggg_data
COPY dgd/ggg_metrics /workspace/ggg_metrics
COPY dgd/models /workspace/models
COPY *.py /workspace/
RUN pip install --no-cache-dir -e .
# Expects you to run the image with `docker run -e WANDB_API_KEY=$YOUR_API_KEY graphgendiff:latest`.
# Exec form so python runs as PID 1 and receives SIGTERM from `docker stop`.
CMD ["python", "main.py"]
