FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-runtime

# Image name, exported into the container environment.
ENV IMG_NAME=model_risk
# NOTE(review): no jax/jaxlib version argument is declared; jax is installed
# unpinned below. If a CI/CD system is expected to choose the version, add an
# ARG (e.g. `ARG JAX_VERSION`) and reference it in the jax install step.

# System packages. `swig` is installed here (rather than in the box2d step) so
# that update + install + list cleanup all happen in one layer.
# Recommended packages are deliberately NOT suppressed (no --no-install-recommends):
# NOTE(review): python3-pip's recommends presumably provide the C/C++ toolchain
# the gym[box2d] build relies on; verify before trimming.
RUN apt-get update \
    && DEBIAN_FRONTEND=noninteractive apt-get install -y \
        git \
        python3-pip \
        swig \
    && rm -rf /var/lib/apt/lists/*

# Stop JAX/XLA from preallocating most of the GPU memory at startup.
ENV XLA_PYTHON_CLIENT_PREALLOCATE=false

# JAX with CUDA 12 support (unpinned; see the NOTE above).
RUN pip install --no-cache-dir "jax[cuda12]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Python dependencies. The installs intentionally stay separate and in this order:
#   - stable-baselines3 is installed with --no-deps, so it must not share a
#     resolver run with the other packages;
#   - airsim is installed only after msgpack-rpc-python is present (presumably
#     airsim's setup needs it at install time; verify before reordering).
# NOTE(review): most packages are unpinned, so rebuilds are not reproducible.
# Consider a pinned requirements/constraints file. numpy is also unpinned, and
# newer numpy releases may be incompatible with pandas==2.0.3 / gym==0.23.
RUN pip install --no-cache-dir flax \
    && pip install --no-cache-dir optax \
    && pip install --no-cache-dir pandas==2.0.3 \
    && pip install --no-cache-dir orbax \
    && pip install --no-cache-dir matplotlib \
    && pip install --no-cache-dir gymnasium \
    && pip install --no-cache-dir gym==0.23 \
    && pip install --no-cache-dir stable-baselines3 --no-deps \
    && pip install --no-cache-dir fire \
    && pip install --no-cache-dir msgpack-rpc-python \
    && pip install --no-cache-dir airsim \
    && pip install --no-cache-dir distrax \
    && pip install --no-cache-dir pathos \
    && pip install --no-cache-dir plotly

# gym's Box2D extra (built with the swig installed above). The specifier is
# quoted so the shell cannot glob-expand `[box2d]`. It is pinned to the gym
# version installed above, so pip adds the extra without changing gym itself.
RUN pip install --no-cache-dir "gym[box2d]==0.23"

# install mujoco (currently disabled: every step below is commented out)

# ARG DEBIAN_FRONTEND=noninteractive
# RUN apt install wget -y &&   apt install libglfw3 -y &&  apt install libglfw3-dev -y &&  apt install  libgl1-mesa-dev \
#     libgl1-mesa-glx \
#    libglew-dev \
#    libosmesa6-dev \
#    gcc \
#    patchelf -y &&  pip install imageio &&  pip install mujoco-py && pip install mujoco==2.2.0

#RUN mkdir -p /root/.mujoco \
#    && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \
#    && tar -xf mujoco.tar.gz -C /root/.mujoco \
#    && rm mujoco.tar.gz && pip install "Cython < 1"

# ENV LD_LIBRARY_PATH /root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
# ENV LD_LIBRARY_PATH /usr/local/nvidia/lib64:${LD_LIBRARY_PATH}
# compile
# RUN echo "import gym; gym.make('Hopper-v3'); print('compiled')" | python3 &&  apt-get clean

