# Development image: PyTorch + Lightning + Jupyter tooling, plus a source-built torch-scatter.
# PYTORCH_VERSION selects the base image AND pins torch for pip below, so the two cannot drift.
ARG PYTORCH_VERSION=1.13.0
FROM pytorch/pytorch:${PYTORCH_VERSION}-cuda11.6-cudnn8-devel

# Re-declare so the pre-FROM ARG's default is visible inside this stage.
ARG PYTORCH_VERSION

# Presumably read by extension setup scripts (e.g. PyG's torch-scatter) to build CUDA kernels
# even when no GPU is visible during `docker build` — confirm for the pinned torch-scatter version.
ENV FORCE_CUDA="1"
# GPU compute capabilities that torch.utils.cpp_extension compiles for. Build ARGs are
# exposed as environment variables to subsequent RUN steps, which is what the build reads.
ARG TORCH_CUDA_ARCH_LIST="6.1;7.5"

# No OS packages are installed here. A bare `apt-get update` layer only bakes in stale lists.
# If packages are needed, use a single RUN:
#   apt-get update && apt-get install -y --no-install-recommends <pkgs> && rm -rf /var/lib/apt/lists/*

# Python tooling. "torch==${PYTORCH_VERSION}" keeps the resolver from replacing the base image's
# CUDA build of torch. Without it, unpinned `lightning` can resolve to a release that needs a newer
# torch. --upgrade applies to every listed package (pip options are not positional).
RUN pip install --no-cache-dir --upgrade \
        "torch==${PYTORCH_VERSION}" \
        jupyterlab \
        lightning \
        matplotlib \
        neptune \
        pudb \
        scikit-learn \
        seaborn \
        tonic \
        urllib3

# NOTE(review): torch-scatter 1.3.2 predates torch 1.13 by several releases. Confirm it still
# compiles against this torch, or switch to a matching prebuilt wheel, e.g.
#   pip install torch-scatter -f https://data.pyg.org/whl/torch-1.13.0+cu116.html
# (newer torch-scatter releases changed some function signatures — check callers first).
RUN pip install --no-cache-dir torch-scatter==1.3.2