#!/bin/bash
# setup_vllm012_venv.sh
#
# Create a fresh Python virtual environment with:
#   - PyTorch 2.9.0 + CUDA 12.9 (cu129)
#   - vLLM 0.12.0
#   - VERL 0.7.0
#   - Training deps (pandas, datasets, peft, ray, etc.)
#   - vLLM LoRA PDL patch for SM100a (B200)
#
# Mirrors the environment in verlai/verl:vllm012.latest Docker image.
# Used for B200 testing to diagnose the entropy gap / high grad_norm issue.
#
# Usage:
#   bash scripts/setup_vllm012_venv.sh [venv_path]
#   bash scripts/setup_vllm012_venv.sh $HOME/verl-vllm012
#
# Venv path: the first argument if given, else $VLLM_VENV_PATH if set,
# else $HOME/verl-vllm012/. A leading "~" is expanded to $HOME.

# Abort on the first unhandled failing command.
set -o errexit

# ============================================
# Configuration
# ============================================

# Venv location, in priority order: first CLI argument, then $VLLM_VENV_PATH,
# then ~/verl-vllm012.
VENV_PATH="${1:-${VLLM_VENV_PATH:-$HOME/verl-vllm012}}"

# A quoted "~" (from an argument or env var) is not expanded by the shell,
# so substitute $HOME for a leading tilde ourselves.
VENV_PATH="${VENV_PATH/#\~/$HOME}"

cat <<EOF
==========================================
vLLM 0.12.0 + PyTorch 2.9.0 venv setup
==========================================
Venv path: $VENV_PATH
CUDA target: 12.9 (cu129)

EOF

# ============================================
# Verify python3.12 is available
# ============================================

echo "Checking for Python 3.12..."
# Probe, in order: python3.12 on PATH, a uv-managed CPython 3.12.10, and the
# interpreter behind the verl-latest venv (accepted only if 3.10, 3.11 or 3.12).
if command -v python3.12 &>/dev/null; then
    PYTHON_BIN="python3.12"
elif [[ -f "$HOME/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/bin/python3.12" ]]; then
    PYTHON_BIN="$HOME/.local/share/uv/python/cpython-3.12.10-linux-x86_64-gnu/bin/python3.12"
elif [[ -f "$HOME/verl-latest/bin/python3" ]]; then
    # Resolve the venv's symlink so the new venv is built from the same base
    # interpreter as verl-latest (may be a uv-managed 3.12).
    REAL_PY=$(readlink -f "$HOME/verl-latest/bin/python3")
    # "Python 3.12.10" -> "3.12.10" -> "3.12"
    PY_VER=$("$REAL_PY" --version 2>&1 | awk '{print $2}')
    PY_MAJOR_MINOR=$(cut -d. -f1-2 <<<"$PY_VER")
    case "$PY_MAJOR_MINOR" in
        3.10 | 3.11 | 3.12)
            PYTHON_BIN="$REAL_PY"
            ;;
        *)
            echo "ERROR: verl-latest Python is $PY_VER, need 3.10, 3.11 or 3.12" >&2
            exit 1
            ;;
    esac
else
    echo "ERROR: No suitable Python found. Need python3.12 on PATH, uv CPython 3.12.10, or a verl-latest venv on Python 3.10-3.12." >&2
    exit 1
fi

# Quoted: PYTHON_BIN may be a $HOME-based path containing spaces.
PYTHON_VERSION=$("$PYTHON_BIN" --version)
echo "Found: $PYTHON_VERSION at $PYTHON_BIN"

# ============================================
# Create virtual environment
# ============================================

# Reuse an existing venv (identified by its activate script), refuse to touch
# a path that exists but is not a venv, otherwise create a fresh one.
if [[ -f "$VENV_PATH/bin/activate" ]]; then
    echo ""
    echo "WARNING: Venv already exists at $VENV_PATH"
    echo "Delete it first to start fresh, or press Ctrl+C to abort."
    echo "Continuing with existing venv (packages will be added/upgraded)..."
    echo ""
elif [[ -e "$VENV_PATH" ]]; then
    # Previously such a path was treated as a venv and the script died later
    # at `source .../bin/activate` with an unhelpful error.
    echo "" >&2
    echo "ERROR: $VENV_PATH exists but is not a virtual environment (no bin/activate)." >&2
    echo "Remove it or pass a different venv path." >&2
    exit 1
else
    echo ""
    echo "Creating virtual environment at $VENV_PATH ..."
    "$PYTHON_BIN" -m venv "$VENV_PATH"
    echo "Venv created."
fi

# ============================================
# Activate venv
# ============================================

echo ""
echo "Activating venv..."
# shellcheck disable=SC1091  # activate script is generated at runtime
source "$VENV_PATH/bin/activate"
# `command -v` is a shell builtin; `which` is missing in many minimal images.
echo "Active Python: $(command -v python3)"
echo "Python version: $(python3 --version)"

# ============================================
# Upgrade pip/setuptools/wheel first
# ============================================

# Refresh the packaging toolchain inside the venv before anything else.
printf '\n%s\n' "Upgrading pip, setuptools, wheel..."
pip install -U pip setuptools wheel

# ============================================
# Install PyTorch 2.9.0 with CUDA 12.9
# ============================================

# Pull torch from the cu129 wheel index rather than PyPI's default build.
printf '\n%s\n' "Installing PyTorch 2.9.0 (cu129)..."
pip install --index-url https://download.pytorch.org/whl/cu129 'torch==2.9.0'

printf 'PyTorch installed: %s\n' "$(python3 -c 'import torch; print(torch.__version__)')"

# ============================================
# Install vLLM 0.12.0
# ============================================

# Pinned vLLM release matching the reference Docker image.
printf '\n%s\n' "Installing vLLM 0.12.0..."
pip install 'vllm==0.12.0'

printf 'vLLM installed: %s\n' "$(python3 -c 'import vllm; print(vllm.__version__)')"

# ============================================
# Install VERL 0.7.0
# ============================================

# Pinned VERL release; its version attribute may be absent, so fall back to a note.
printf '\n%s\n' "Installing VERL 0.7.0..."
pip install 'verl==0.7.0'

printf 'VERL installed: %s\n' "$(python3 -c 'import verl; print(verl.__version__)' 2>/dev/null || echo '(version attr not available)')"

# ============================================
# Install training dependencies
# ============================================

# Unpinned training-side packages; the ray extra is quoted to avoid globbing.
printf '\n%s\n' "Installing training dependencies..."
pip install pandas pyarrow datasets transformers peft accelerate \
    wandb 'ray[default]' hydra-core

printf '%s\n' "Training deps installed."

# ============================================
# Apply vLLM LoRA PDL patch for SM100a (B200)
# ============================================
# Disables PDL (Programmatic Dependent Launch) for compute capability >= 100.
# PDL gdc_wait() calls cause illegal instruction crashes on SM100a.
# Upstream fix: https://github.com/vllm-project/vllm/issues/30872

#######################################
# Rewrite vLLM's PDL capability check so it is false on capability >= 100.
# Arguments: $1 - path to vllm/lora/ops/triton_ops/utils.py
# Outputs:   progress / warning lines on stdout
# Returns:   0 if patched now, already patched, or no capability(90) check
#            is present; 1 if the check exists but the rewrite did not match.
#######################################
patch_vllm_lora_pdl() {
    local utils_file=$1

    if ! grep -q "has_device_capability(90)" "$utils_file" \
        || grep -q "has_device_capability(100)" "$utils_file"; then
        echo "  PDL patch already applied or not needed (skipping)."
        return 0
    fi

    echo "  File: $utils_file"
    echo "  Patching: disabling PDL for SM100a (B200)..."
    sed -i 's/return current_platform.is_cuda() and current_platform.has_device_capability(90)/return current_platform.is_cuda() and current_platform.has_device_capability(90) and not current_platform.has_device_capability(100)/' "$utils_file"

    # sed exits 0 even when nothing matched, so confirm the rewrite landed
    # instead of reporting success unconditionally.
    if grep -q "not current_platform.has_device_capability(100)" "$utils_file"; then
        echo "  Patch applied. LoRA Triton kernels will compile without gdc_wait() on B200."
        return 0
    fi
    echo "  WARNING: has_device_capability(90) found, but the expected PDL check line did not match."
    echo "  PDL patch NOT applied — B200 LoRA may crash."
    return 1
}

echo ""
echo "Applying vLLM LoRA PDL patch for SM100a (B200)..."

# `|| true`: under `set -e` a failing command substitution in an assignment
# aborts the script, silently here because stderr is discarded. Fall through
# to the warning branch instead.
UTILS_FILE=$(python3 -c "import vllm; import os; print(os.path.join(os.path.dirname(vllm.__file__), 'lora/ops/triton_ops/utils.py'))" 2>/dev/null) || true

if [ -n "$UTILS_FILE" ] && [ -f "$UTILS_FILE" ]; then
    # Best-effort: the helper already reports a failed patch.
    patch_vllm_lora_pdl "$UTILS_FILE" || true
else
    echo "  WARNING: Could not locate vllm/lora/ops/triton_ops/utils.py"
    echo "  File path: '${UTILS_FILE}'"
    echo "  PDL patch NOT applied — B200 LoRA may crash."
fi

# ============================================
# Install flash-attn (FA2)
# ============================================

echo ""
echo "Installing flash-attn (FA2) — required for B200 FLASH_ATTN backend..."
# With --no-build-isolation, flash-attn's build runs against this venv's
# packages, so its build requirements must already be installed here.
# Without ninja, a source build falls back to a very slow single-core compile.
pip install packaging psutil ninja
pip install flash-attn --no-build-isolation
echo "flash-attn installed."

# ============================================
# Version summary
# ============================================

echo ""
echo "=========================================="
echo "Version Summary"
echo "=========================================="

# Report installed versions. Each optional package is imported inside its own
# try block, so one broken install is reported instead of aborting the summary
# (and, via set -e, the whole script).
python3 -c "
import sys

import torch

print(f'Python:  {sys.version.split()[0]}')
print(f'PyTorch: {torch.__version__}')
print(f'CUDA (torch): {torch.version.cuda}')

try:
    import vllm
    print(f'vLLM:    {vllm.__version__}')
except Exception as e:
    print(f'vLLM:    ERROR - {e}')

try:
    import verl
    print(f'VERL:    {verl.__version__}')
except AttributeError:
    print('VERL:    installed (no __version__ attr)')
except Exception as e:
    print(f'VERL:    ERROR - {e}')

try:
    import flash_attn
    print(f'flash-attn: {flash_attn.__version__}')
except ModuleNotFoundError:
    print('flash-attn: not installed')
except Exception as e:
    # e.g. an undefined-symbol ImportError from a torch ABI mismatch
    print(f'flash-attn: ERROR - {e}')

if torch.cuda.is_available():
    print(f'GPU:     {torch.cuda.get_device_name(0)}')
    major, minor = torch.cuda.get_device_capability(0)
    # Arch-specific 'a' targets (sm_90a, sm_100a, ...) exist only from SM90 on.
    suffix = 'a' if major >= 9 else ''
    print(f'SM:      sm{major}{minor}{suffix}')
else:
    print('GPU:     not available (CUDA not detected at summary time)')
"

# Closing instructions: how to activate the venv and the env vars B200 runs expect.
cat <<EOF

==========================================
Setup complete!

To activate this venv:
  source $VENV_PATH/bin/activate

Required env vars for B200 training:
  export VLLM_USE_V1=1
  export VLLM_USE_TRTLLM_ATTENTION=0
  export VLLM_ATTENTION_BACKEND=FLASH_ATTN
  unset PYTORCH_CUDA_ALLOC_CONF
  export WANDB_CONSOLE=off
  export PYTHONUNBUFFERED=1
==========================================
EOF
