# CUDA 11.0 + cuDNN 8 development image. The CUDA version has to line up with
# the jaxlib "+cuda110" wheel installed later in this file. Pinned to the full
# patch version (11.0.3); NOTE(review): short aliases such as "11.0-..." are
# not guaranteed to remain published on Docker Hub.
FROM nvidia/cuda:11.0.3-cudnn8-devel-ubuntu18.04

# Build-time only: suppresses interactive apt/debconf prompts during RUN steps
# without leaking the setting into the runtime environment of the container.
ARG DEBIAN_FRONTEND=noninteractive

# System toolchain and Python 3 for the editable installs below.
# The CUDA/ML apt sources bundled with older nvidia/cuda ubuntu18.04 images
# were signed with a key NVIDIA rotated in 2022, which can make
# `apt-get update` fail with NO_PUBKEY. Nothing here installs from those
# repos, so drop them first (`rm -f` is a no-op if they are absent).
# python3-wheel is listed explicitly because --no-install-recommends would
# otherwise skip it (it is only a recommendation of python3-pip).
RUN rm -f /etc/apt/sources.list.d/cuda.list /etc/apt/sources.list.d/nvidia-ml.list \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        cmake \
        git \
        python3-dev \
        python3-pip \
        python3-setuptools \
        python3-wheel \
        ssh \
        unzip \
        vim \
        wget \
    && rm -rf /var/lib/apt/lists/*

# Make bare `python` / `pip` resolve to the Python 3 tools installed via apt.
# The command substitutions are quoted so a failed lookup makes `ln` error out
# instead of silently creating a stray link in the current directory.
RUN ln -sf "$(command -v python3)" /usr/bin/python \
    && ln -sf "$(command -v pip3)" /usr/bin/pip

# Use the C.UTF-8 locale for every process in the image (e.g. so Python 3
# defaults to UTF-8 text I/O instead of ASCII under the POSIX locale).
ENV LANG=C.UTF-8 \
    LC_ALL=C.UTF-8

# /spe is both the source root and the container's working directory.
WORKDIR /spe

# Upgrading pip does not depend on the build context, so it runs before the
# COPY; this layer then stays cached when repository files change.
RUN python -m pip install --no-cache-dir --upgrade pip

COPY . /spe

# Editable installs of the in-repo packages, kept in their original order.
# Paths are relative to WORKDIR /spe. `python -m pip` guarantees the upgraded
# pip of the same interpreter is used; --no-cache-dir keeps pip's download
# cache out of the image layers.
RUN python -m pip install --no-cache-dir -e ./lra/long-range-arena
RUN python -m pip install --no-cache-dir -e ./lra/fast_attention
RUN python -m pip install --no-cache-dir -e ./src/jax
RUN python -m pip install --no-cache-dir -e ./positional-bias

# CUDA 11.0 build of jaxlib, matching the base image. It runs after the
# editable installs, presumably so any CPU jaxlib they pull in gets replaced —
# NOTE(review): confirm before reordering.
RUN python -m pip install --no-cache-dir --upgrade jaxlib==0.1.68+cuda110 \
        -f https://storage.googleapis.com/jax-releases/jax_releases.html
RUN python -m pip install --no-cache-dir -r lra/requirements.txt
