#!/bin/bash
#
# Creates the "cpr" conda environment (Python 3.10) and installs the pinned
# dependencies into it, finishing with an editable install of the current
# directory (`pip install -e .`), so run it from the project root.
#
# Requires `conda` to be callable from the invoking shell.

# Enable tracing and fail-fast only when the file is executed. When it is
# sourced, the caller's shell options are left alone so that a failing step
# cannot close the caller's interactive shell.
# `-x` lives here instead of the shebang because shebang options are dropped
# when the script is started as `bash <script>`.
# `-u` is deliberately omitted: conda activation scripts have historically
# referenced unset variables and would abort under it.
if [[ "${BASH_SOURCE[0]}" == "${0}" ]]; then
  set -x
  set -eo pipefail
fi

# `conda activate` is a shell function that only exists after conda's shell
# integration has been loaded; a non-interactive script does not get it
# automatically, and without it every `pip install` below would target
# whatever Python happens to be first on PATH instead of the new environment.
conda_base="$(conda info --base)"
# shellcheck disable=SC1091  # path is only known at run time
source "${conda_base}/etc/profile.d/conda.sh"
unset conda_base

conda create --name cpr python=3.10 -y
conda activate cpr

# Packaging tools first, so the source builds below use up-to-date versions.
pip install --upgrade pip setuptools wheel
pip install tqdm pyyaml pyaml numpy scipy pandas scikit-learn matplotlib
conda install pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia -y
pip install triton==2.0.0.dev20221202 ninja
pip install tensorboard==2.13.0 transformers==4.30.2 datasets==2.13.1 pytorch-lightning==2.0.4
pip install nnunetv2==2.2
# --no-build-isolation: build flash-attn against the packages already
# installed in this environment rather than in an isolated build env.
pip install flash-attn==2.0.0.post1 --no-build-isolation
pip install -e .

# Build and install the extra extension packages that live under csrc/ in the
# flash-attention repository (pinned to tag v2.0.0), then remove the checkout.
#
# Everything runs in a subshell so the caller's working directory is never
# changed, and so that `exit` on failure ends only this step's subshell.
# The subshell's exit status (non-zero on any failed step) is the status of
# this command.
(
  repo_dir="${PWD}/flash-attention"

  git clone https://github.com/HazyResearch/flash-attention "${repo_dir}" || exit 1

  # Installed only after a successful clone, so a pre-existing directory of
  # the same name (which makes the clone fail) is never deleted. From here on
  # the checkout is removed on every exit path, including build failures, so
  # the script can simply be re-run.
  trap 'rm -rf -- "${repo_dir}"' EXIT

  # Mark the actual clone location as safe for git (rather than a fixed
  # /home/user path), and avoid appending a duplicate entry to the global
  # git config on every run.
  if ! git config --global --get-all safe.directory | grep -Fxq -- "${repo_dir}"; then
    git config --global --add safe.directory "${repo_dir}" || exit 1
  fi

  git -C "${repo_dir}" checkout v2.0.0 || exit 1

  # Each extension is an independent package; install from inside its own
  # directory, in the same order as before.
  for ext in fused_softmax rotary xentropy layer_norm fused_dense_lib ft_attention; do
    (cd "${repo_dir}/csrc/${ext}" && pip install .) || exit 1
  done
)