#!/usr/bin/env bash
# set -e

ROOT="src/VILA"
# Everything below (".[train,eval]", ./llava/...) assumes we are inside the
# VILA checkout, so abort instead of installing from the wrong directory.
cd "${ROOT}" || { echo "error: cannot cd to ${ROOT}" >&2; exit 1; }
# Add the checkout to PYTHONPATH by absolute path: a relative "src/VILA" entry
# would resolve against the new working directory (src/VILA/src/VILA) after the
# cd above. ${PYTHONPATH:+...} avoids a leading ':' when PYTHONPATH is unset.
export PYTHONPATH="${PYTHONPATH:+${PYTHONPATH}:}${PWD}"

CONDA_ENV=${1:-""}
if [ -n "$CONDA_ENV" ]; then
    # This is required to activate conda environment (loads conda's shell hook)
    eval "$(conda shell.bash hook)"

    conda create -n "$CONDA_ENV" python=3.10.14 -y
    # Stop here rather than installing everything into whatever env is active.
    conda activate "$CONDA_ENV" || { echo "error: cannot activate conda env '$CONDA_ENV'" >&2; exit 1; }
    # This is optional if you prefer to use built-in nvcc
    conda install -c nvidia cuda-toolkit -y

    # Install gcc and g++ compilers
    conda install -y gcc_linux-64 gxx_linux-64
    # Expose the conda compilers as plain gcc/g++. The executable prefix can
    # vary between gcc_linux-64 builds, so link the first candidate that exists;
    # linking a missing target would leave dangling gcc/g++ links in the env's
    # bin/ that shadow the system compilers. -f replaces any existing link.
    linked_compiler=""
    for triplet in x86_64-conda_cos7-linux-gnu x86_64-conda-linux-gnu; do
        if [ -x "$CONDA_PREFIX/bin/${triplet}-gcc" ]; then
            ln -sf "$CONDA_PREFIX/bin/${triplet}-gcc" "$CONDA_PREFIX/bin/gcc"
            ln -sf "$CONDA_PREFIX/bin/${triplet}-g++" "$CONDA_PREFIX/bin/g++"
            linked_compiler=$triplet
            break
        fi
    done
    if [ -z "$linked_compiler" ]; then
        echo "warning: no conda gcc found in $CONDA_PREFIX/bin; gcc/g++ not linked" >&2
    fi
else
    echo "Skipping conda environment creation. Make sure you have the correct environment activated."
fi

# Recent pip/setuptools are needed for PEP 660 editable installs, which the
# `pip install -e` below relies on.
pip install --upgrade pip setuptools

pip install venus-tools

# FlashAttention2 prebuilt wheel (tags in the name: CUDA 12.2, torch 2.3,
# cxx11abi FALSE, CPython 3.10, linux x86_64).
# NOTE(review): torch is pinned to 2.6.0 further down; a wheel compiled against
# torch 2.3 may fail to import under 2.6 — confirm, or switch to a matching wheel.
flash_attn_wheel="https://github.com/Dao-AILab/flash-attention/releases/download/v2.5.8/flash_attn-2.5.8+cu122torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
pip install "$flash_attn_wheel"

# Install VILA itself in editable mode, with its training and evaluation extras.
pip install -e ".[train,eval]"

# Quantization needs a newer triton than the one resolved above; it is pinned
# separately because it introduces a dependency conflict.
pip install "triton==3.2.0"

# numpy causes many dependency conflicts, so it is kept out of pyproject;
# uncomment to pin it explicitly.
# pip install numpy==1.26.4

# Downgrade protobuf to 3.20 for backward compatibility
# (quoted: an unquoted 3.20.* is a shell glob).
pip install "protobuf==3.20.*"

pip install tensorboard
pip install diffusers==0.33.1
pip install imageio-ffmpeg
pip install opencv-python
pip install omegaconf
pip install transformers==4.46.0
pip install peft

# Pin torchvision/torchaudio to the releases paired with torch 2.6.0. Upgrading
# them to the latest and then downgrading torch alone would leave them built
# against a different torch than the one installed.
pip3 install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124
# Resolving xformers together with the torch pin makes pip pick a compatible build.
pip3 install torch==2.6.0 xformers

# Put the conda env's targets/x86_64-linux/include on the compiler's header
# path for the deepspeed build below (presumably CUDA headers from the
# cuda-toolkit package — confirm).
if [ -n "$CONDA_ENV" ]; then
    CONDA_ENV_PATH=$(conda info --base)/envs/$CONDA_ENV
else
    # No env name was given: use the currently active environment, if any.
    CONDA_ENV_PATH=${CONDA_PREFIX:-}
fi
if [ -n "$CONDA_ENV_PATH" ]; then
    # Append the old CPATH only when non-empty: an empty CPATH element makes
    # GCC search the current working directory for headers.
    export CPATH="$CONDA_ENV_PATH/targets/x86_64-linux/include${CPATH:+:$CPATH}"
else
    echo "warning: no conda environment detected; CPATH not modified" >&2
fi

# NOTE(review): if deepspeed 0.9.5 is already installed (e.g. by the editable
# install's extras), pip reports it as satisfied and does not rebuild, so
# DS_BUILD_CPU_ADAM has no effect. --global-option is also deprecated in recent
# pip releases. Confirm the intended build actually happens.
DS_BUILD_CPU_ADAM=1 pip install --no-cache-dir "deepspeed==0.9.5" --global-option="build_ext" --global-option="-j8"

# Replace deepspeed files with the project's patched copies. This runs after
# every step that may (re)install deepspeed, so a reinstall cannot undo it.
site_pkg_path=$(python -c 'import site; print(site.getsitepackages()[0])')
if [ -d "$site_pkg_path/deepspeed" ]; then
    cp -rv ./llava/train/deepspeed_replace/* "$site_pkg_path/deepspeed/"
else
    echo "warning: $site_pkg_path/deepspeed not found; deepspeed files not replaced" >&2
fi

pip install --upgrade bitsandbytes

# Install an ffmpeg binary via conda only when none is already on PATH.
if ! command -v ffmpeg >/dev/null 2>&1; then
    if command -v conda >/dev/null 2>&1; then
        conda install -y ffmpeg
    else
        echo "warning: ffmpeg not found and conda unavailable; install ffmpeg manually" >&2
    fi
fi

# Quote the extras spec: unquoted, [ffmpeg] is a glob bracket expression and
# can expand to a matching filename in the current directory.
pip install "imageio[ffmpeg]"