# Base image: official PyTorch 2.1.2 build with CUDA 12.1 and cuDNN 8.
# The "-devel" variant ships the CUDA toolkit (nvcc), which source builds of
# CUDA extension packages presumably need — TODO confirm if switching variants.
# NOTE(review): consider pinning by digest (@sha256:...) for fully
# reproducible builds.
FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-devel

# Working directory for the rest of the build and for the running container.
# Absolute path on purpose: the previous relative form ("../workspace")
# resolved against whatever WORKDIR the parent image set; "/workspace" is the
# same directory without depending on the parent image.
WORKDIR /workspace

# System packages. --no-install-recommends keeps the layer small; the
# recommended git dependencies that are needed in practice are listed
# explicitly instead: ca-certificates (cloning over HTTPS) and
# openssh-client (ssh remotes). The apt lists are removed in the same layer
# so they never persist in the image.
RUN apt-get update && apt-get install -y --no-install-recommends \
    ca-certificates \
    git \
    openssh-client \
 && rm -rf /var/lib/apt/lists/*

# Copy only the requirements file (not the whole source tree) so the
# dependency layer below stays cached until requirements.txt changes.
COPY requirements.txt .

# Upgrade pip, then install the project's Python dependencies.
# --no-cache-dir on both calls keeps pip's download/wheel cache out of the
# image layer.
RUN pip install --no-cache-dir --upgrade pip \
 && pip install --no-cache-dir -r requirements.txt

# Pinned CUDA extension packages (causal-conv1d, FlashAttention, Mamba SSM).
# They are compiled against the torch already installed in the image, so:
#   * --no-build-isolation lets their build scripts import that torch
#     (flash-attn's install instructions call for this flag);
#   * with isolation off, pip no longer provisions build tools, so
#     packaging/ninja/wheel must already be present. They are installed here
#     only if missing, deliberately unpinned so versions already provided by
#     the base image or requirements.txt are left untouched. ninja also
#     speeds up compilation considerably.
# NOTE(review): confirm flash-attn 2.7.3 supports the torch version shipped in
# the base image; flash-attn's README lists a minimum PyTorch version.
RUN pip install --no-cache-dir packaging ninja wheel \
 && pip install --no-cache-dir --no-build-isolation \
      causal-conv1d==1.5.0.post8 \
      flash-attn==2.7.3 \
      mamba-ssm==2.2.2

# Default command: start an interactive bash shell. Exec (JSON-array) form, so
# bash runs directly as PID 1 rather than under `/bin/sh -c`. Override it on
# `docker run` to execute a specific script instead.
CMD ["/bin/bash"]
