# Base image: pytorch/pytorch at PyTorch 2.4.1 with CUDA 12.1 and cuDNN 9
# ("devel" variant of the image).
# The commented-out FROM lines are alternative base images kept for reference
# (presumably earlier choices -- confirm before removing).
# FROM pytorch/pytorch:2.2.2-cuda12.1-cudnn8-devel
FROM pytorch/pytorch:2.4.1-cuda12.1-cudnn9-devel
# FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-runtime
# Upgrade pip, then install the Python dependencies in a single layer.
# --no-cache-dir keeps pip's download cache out of the image layer.
# Packages are listed one per line in alphabetical order so that adding or
# removing a dependency produces a one-line diff.
# NOTE(review): versions are unpinned, so rebuilds may resolve to different
# releases; consider pinning once a known-good set is established.
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir \
        accelerate \
        ConfigArgParse \
        datasets \
        hydra-core \
        ipython \
        jupyter \
        jupyterlab \
        lightning \
        mlflow \
        mp-api \
        mysql-connector-python \
        opentsne \
        pandas \
        pymatgen \
        scikit-learn \
        scipy \
        seqeval \
        tensorboard \
        tokenizers \
        torchtext \
        tqdm \
        transformers
# Install the `pyg` package from the `pyg` conda channel.
# --yes answers the confirmation prompt so the build runs unattended; the
# conda cache is cleaned in the same RUN so it is not stored in this layer.
RUN conda install --yes --channel pyg pyg \
    && conda clean --all --yes
# Install faiss-gpu 1.8.0, searching the `pytorch` channel first and then
# `nvidia` (channel order is preserved from the original command).
# --yes is required because a docker build cannot answer conda's confirmation
# prompt; the conda cache is cleaned in the same RUN so it is not stored in
# this layer.
RUN conda install --yes --channel pytorch --channel nvidia faiss-gpu=1.8.0 \
    && conda clean --all --yes