# Start from the NVIDIA NGC PyTorch image. The image reference is a build arg so
# the base can be bumped (docker build --build-arg BASE_IMAGE=...) without
# editing this file; the default is the original 23.10 image.
# NOTE(review): for fully reproducible builds, consider pinning by digest,
# e.g. nvcr.io/nvidia/pytorch:23.10-py3@sha256:<digest>.
ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:23.10-py3
FROM ${BASE_IMAGE}

# Install the C/C++ toolchain (build-essential) for any source builds below.
# --no-install-recommends keeps the layer minimal, DEBIAN_FRONTEND is set inline
# only (not baked into the runtime env), and the apt lists are removed in the
# same layer so they never persist in the image.
RUN apt-get update \
 && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
 && rm -rf /var/lib/apt/lists/*

# Upgrade pip before installing anything else (optional, but often recommended).
# --no-cache-dir keeps pip's download cache out of this image layer.
RUN pip install --no-cache-dir --upgrade pip

# flash-attn, pinned. Its setup imports the torch that is already installed in
# the base image, so it must be installed with --no-build-isolation (as the
# flash-attn README instructs); otherwise pip's isolated build env has no torch.
# It is installed BEFORE Transformer Engine so that TE's own flash-attn
# requirement is already satisfied and pip does not resolve/build a different
# version. NOTE(review): confirm the pinned version is inside the flash-attn
# range accepted by the TE ref below (see TE's setup.py).
# If a source build runs out of memory, set MAX_JOBS in the environment to cap
# parallel compile jobs.
ARG FLASH_ATTN_VERSION=2.6.3
RUN pip install --no-cache-dir --no-build-isolation \
        "flash-attn==${FLASH_ATTN_VERSION}"

# Transformer Engine, built from GitHub at the given ref (default: the
# release_v1.12 branch, as before).
# NOTE(review): a branch can move; set TE_REF to a tag or commit SHA for
# reproducible builds.
ARG TE_REF=release_v1.12
RUN pip install --no-cache-dir \
        "git+https://github.com/NVIDIA/TransformerEngine.git@${TE_REF}"

# Weights & Biases client. NOTE(review): unpinned; pin a version for
# reproducible images.
RUN pip install --no-cache-dir wandb

# (Optional) Install any other Python packages here, e.g.:
# RUN pip install --no-cache-dir some-other-package

# When the container is started without an explicit command, open a bash shell.
# Exec (JSON-array) form: bash is launched directly rather than via /bin/sh -c.
CMD [ "/bin/bash" ]
