#!/bin/bash

# Feature toggles: each falls back to 1 when the caller leaves it unset or empty.
#   USE_MEGATRON - 1 to run the TransformerEngine / Megatron-LM steps
#   USE_SGLANG   - 1 to run the SGLang install step
: "${USE_MEGATRON:=1}"
: "${USE_SGLANG:=1}"

# Exported so every child process started by this script inherits it.
# NOTE(review): presumably caps parallel compile jobs for source builds - confirm.
MAX_JOBS=32
export MAX_JOBS

echo "1. install inference frameworks and pytorch they need"
# Optional SGLang install; torch-memory-saver is only attempted when the
# SGLang install itself succeeded. The flag is quoted so an unexpected value
# (spaces, glob characters) cannot be word-split or expanded inside the test.
if [ "$USE_SGLANG" -eq 1 ]; then
    pip install "sglang[all]==0.4.6.post1" --no-cache-dir \
        --find-links https://flashinfer.ai/whl/cu124/torch2.6/flashinfer-python \
        && pip install torch-memory-saver --no-cache-dir
fi

# vLLM together with the pinned torch stack. This runs after the SGLang step,
# so these pins are what ends up installed.
pip install --no-cache-dir \
    "vllm==0.8.5.post1" \
    "torch==2.6.0" "torchvision==0.21.0" "torchaudio==2.6.0" \
    "tensordict==0.6.2" torchdata

echo "2. install basic packages"
# General training / data / tooling dependencies.
# "ray[default]" must be quoted: unquoted, the shell treats [default] as a glob
# character class and would replace the argument with any matching filename
# in the current directory (e.g. a file named "raye").
pip install "transformers[hf_xet]>=4.51.0" accelerate datasets peft hf-transfer \
    "numpy<2.0.0" "pyarrow>=15.0.0" pandas \
    "ray[default]" codetiming hydra-core pylatexenc qwen-vl-utils wandb dill pybind11 liger-kernel mathruler \
    pytest py-spy pyext pre-commit ruff

# Packages that only carry a lower version bound, installed in a second pass.
pip install "nvidia-ml-py>=12.560.30" "fastapi[standard]>=0.115.0" "optree>=0.13.0" "pydantic>=2.9" "grpcio>=1.62.1"

echo "3. install FlashAttention and FlashInfer"
# Both packages are installed from prebuilt release wheels through direct
# "name @ URL" references.
#
# Alternatives that were used previously and are kept for reference:
#   pip install flash_attn==2.7.4.post1 --no-build-isolation --no-cache-dir
#   pip install flashinfer-python==0.2.2.post1 --no-cache-dir
#   wget -nv <wheel URL> && pip install --no-cache-dir <wheel file>
#   (an earlier flash-attn variant downloaded the cxx11abiFALSE wheel:
#    flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp310-cp310-linux_x86_64.whl)

# flash-attn 2.7.4.post1 wheel tagged cu12 / torch2.6 / cxx11abiTRUE / cp310 / linux_x86_64.
# The cp310 tag means pip only accepts this wheel on CPython 3.10.
# NOTE(review): the cxx11abi flag presumably has to match the installed torch
# build; the earlier variant used cxx11abiFALSE - confirm TRUE is intended.
pip install "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl"

# Install flashinfer-0.2.2.post1+cu124 (cxx11abi=False)
# vllm-0.8.3 does not support flashinfer>=0.2.3
# see https://github.com/vllm-project/vllm/pull/15777
# NOTE(review): this script pins vllm==0.8.5.post1 - confirm the constraint
# above still applies to that version.
pip install "flashinfer-python @ https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.2.post1/flashinfer_python-0.2.2.post1+cu124torch2.6-cp38-abi3-linux_x86_64.whl"

# Print directory $1 prepended to the colon-separated list $2.
# When $2 is empty only $1 is printed, so the result never ends in ':'.
# (An empty entry in LD_LIBRARY_PATH is treated as the current directory.)
_prepend_path_entry() {
    printf '%s\n' "$1${2:+:$2}"
}

# Optional Megatron-LM stack: build tools, cuDNN discovery, TransformerEngine,
# then Megatron-LM. Skipped entirely unless USE_MEGATRON is 1.
if [ "$USE_MEGATRON" -eq 1 ]; then
    echo "4. install build dependencies for TransformerEngine"
    pip install cmake ninja --no-cache-dir

    echo "5. Locate cudnn python package (required for TransformerEngine build)"
    # No cuDNN install happens here; the pinned install is done at the end of
    # the script so that other packages cannot override it:
    # pip install nvidia-cudnn-cu12==9.8.0.87 --no-cache-dir
    # Find the first site-packages directory containing nvidia/cudnn/include/cudnn.h;
    # the result is empty when none is found.
    CUDNN_BASE_DIR=$(python -c "import site; import os; cudnn_path = [os.path.join(p, 'nvidia', 'cudnn') for p in site.getsitepackages() if os.path.exists(os.path.join(p, 'nvidia', 'cudnn', 'include', 'cudnn.h'))]; print(cudnn_path[0] if cudnn_path else '')")
    if [ -n "$CUDNN_BASE_DIR" ]; then
        # Set cuDNN paths for CMake to find
        export CUDNN_INCLUDE_DIR="$CUDNN_BASE_DIR/include"
        export CUDNN_LIBRARY_DIR="$CUDNN_BASE_DIR/lib"
        CMAKE_PREFIX_PATH=$(_prepend_path_entry "$CUDNN_BASE_DIR" "${CMAKE_PREFIX_PATH:-}")
        LD_LIBRARY_PATH=$(_prepend_path_entry "$CUDNN_BASE_DIR/lib" "${LD_LIBRARY_PATH:-}")
        export CMAKE_PREFIX_PATH LD_LIBRARY_PATH
        echo "Found cuDNN at: $CUDNN_BASE_DIR"
        echo "Setting CUDNN_INCLUDE_DIR=$CUDNN_INCLUDE_DIR"
        echo "Setting CUDNN_LIBRARY_DIR=$CUDNN_LIBRARY_DIR"
    else
        # Diagnostics go to stderr; the build is still attempted.
        echo "Warning: Could not find cuDNN installation" >&2
    fi

    echo "6. install TransformerEngine and Megatron"
    echo "Notice that TransformerEngine installation can take very long time, please be patient"
    # Uses `pip` like every other step of this script, so these packages land
    # in the same environment as the torch install that
    # --no-build-isolation relies on.
    NVTE_FRAMEWORK=pytorch pip install --no-deps --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@v2.2 --no-cache-dir
    pip install --no-deps git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.12.0rc3 --no-cache-dir
fi


# Step label is 7: labels 5 and 6 are already used inside the Megatron branch,
# and duplicate numbers make the install log ambiguous.
echo "7. May need to fix opencv"
pip install opencv-python
# The opencv-fixer auto-repair only runs when the fixer package installed.
pip install opencv-fixer && \
    python -c "from opencv_fixer import AutoFix; AutoFix()"


# Pinned cuDNN wheel for the Megatron path, installed as the last pip step so
# that none of the installs above can replace it with another version.
# The flag is quoted so it cannot be word-split or glob-expanded in the test.
if [ "$USE_MEGATRON" -eq 1 ]; then
    echo "8. Install cudnn python package (avoid being overridden)"
    pip install nvidia-cudnn-cu12==9.8.0.87
fi

# NOTE(review): printed unconditionally - nothing in the visible script stops
# on a failed pip command, so this line does not prove every install succeeded.
echo "Successfully installed all packages"
