#!/bin/bash
# Mark every variable assigned from this point on for export, so that
# commands started later by this script see them in their environment.
set -o allexport

# Built-in defaults; each one can be overridden by a command line option.
SERVER_HOST='0.0.0.0'
SERVER_PORT='30000'
MODEL_PATH='Qwen/Qwen2.5-32B-Instruct'

# Parse command line arguments

# Print the option summary to stdout (callers redirect it to stderr on error).
usage() {
    echo "Usage: $0 [OPTIONS]"
    echo "  -m, --model MODEL_PATH    Model path (default: Qwen/Qwen2.5-32B-Instruct)"
    echo "  --port PORT               Server port (default: 30000)"
    echo "  --host HOST               Server host (default: 0.0.0.0)"
    echo "  -h, --help                Show this help and exit"
}

# Apply command line options to MODEL_PATH, SERVER_PORT and SERVER_HOST.
# Arguments: the script's positional parameters ("$@").
# Exits 0 after printing help for -h/--help.
# Exits 1 (message on stderr) for an unknown option or an option that is
# missing its value.
parse_args() {
    while [[ $# -gt 0 ]]; do
        case "$1" in
            -m|--model|--port|--host)
                # Every one of these options consumes the next argument.
                # Without this check, `shift 2` with a single remaining
                # parameter fails without shifting anything and the loop
                # would spin forever.
                if [[ $# -lt 2 ]]; then
                    echo "Error: option $1 requires a value" >&2
                    usage >&2
                    exit 1
                fi
                case "$1" in
                    -m|--model) MODEL_PATH="$2" ;;
                    --port)     SERVER_PORT="$2" ;;
                    --host)     SERVER_HOST="$2" ;;
                esac
                shift 2
                ;;
            -h|--help)
                usage
                exit 0
                ;;
            *)
                echo "Unknown option: $1" >&2
                usage >&2
                exit 1
                ;;
        esac
    done
}

parse_args "$@"

# Bail out early when no `conda` command can be found
command -v conda > /dev/null 2>&1 || {
    echo "Error: conda is not available"
    exit 1
}

# Evaluate the bash hook code that conda prints, in the current shell
eval "$(conda shell.bash hook)"

# HF_HOME (and HF_TOKEN) are expected to come from the caller's environment.
# Only HF_HOME is enforced here: stop if it is unset or empty.
if [[ -z "${HF_HOME}" ]]; then
    printf '%s\n' \
        "Error: HF_HOME environment variable is not set" \
        "Please set it before running this script:" \
        "  export HF_HOME=\$SCRATCH/LLMs"
    exit 1
fi

# Activate conda environment
# Stop right away if activation fails; otherwise the import check below would
# run against whatever python3 happens to be first on PATH.
if ! conda activate sglang-env; then
    echo "Error: failed to activate conda environment 'sglang-env'" >&2
    exit 1
fi

# Verify that the now-active python3 can import sglang
# (the import's own stderr output is discarded; only the exit status is used)
if ! python3 -c "import sglang" 2>/dev/null; then
    echo "Error: sglang is not installed in sglang-env environment" >&2
    echo "Please install it: conda activate sglang-env && uv pip install sglang --prerelease=allow" >&2
    exit 1
fi

# Load the gcc/11.2.0 environment module.
# `module` is normally a shell function set up by the site's module system
# and may be missing in a non-interactive shell, so check for it first.
# A failure only produces a warning: the script continues.
if command -v module &> /dev/null; then
    if ! module load gcc/11.2.0; then
        echo "Warning: 'module load gcc/11.2.0' failed; continuing without it" >&2
    fi
else
    echo "Warning: 'module' command not found; skipping 'module load gcc/11.2.0'" >&2
fi

# Set up flashinfer workspace
# A directory named after this shell's PID, placed under $TMPDIR (or /tmp when
# TMPDIR is unset or empty). The path is exported for child processes.
export FLASHINFER_WORKSPACE_BASE="${TMPDIR:-/tmp}/flashinfer_cache_$$"
# Quoted so a TMPDIR containing spaces or glob characters stays one argument;
# `--` keeps a path starting with '-' from being read as an option.
if ! mkdir -p -- "$FLASHINFER_WORKSPACE_BASE"; then
    echo "Error: could not create flashinfer workspace: $FLASHINFER_WORKSPACE_BASE" >&2
    exit 1
fi

# Detect number of GPUs and set tensor parallelism accordingly
#
# Sets the global NUM_GPUS (later passed to the server as --tp):
#   - nvidia-smi missing            -> 1
#   - nvidia-smi lists N >= 1 GPUs  -> N
#   - nvidia-smi lists no GPU lines -> 1, with a warning on stderr
#     (e.g. the tool is installed but fails or finds no device)
# NOTE(review): this counts every GPU nvidia-smi lists and does not look at
# CUDA_VISIBLE_DEVICES — confirm that is the intended behaviour.
detect_num_gpus() {
    local gpu_count

    if ! command -v nvidia-smi &> /dev/null; then
        NUM_GPUS=1
        echo "nvidia-smi not found, defaulting to 1 GPU"
        return 0
    fi

    # Count only "GPU <index>..." lines, so that error text or indented
    # sub-device lines in the listing do not inflate the number.
    gpu_count=$(nvidia-smi --list-gpus | grep -c '^GPU [0-9]')

    if [[ "$gpu_count" =~ ^[1-9][0-9]*$ ]]; then
        NUM_GPUS=$gpu_count
        echo "Detected $NUM_GPUS GPU(s), setting tensor parallelism to $NUM_GPUS"
    else
        # A count of 0 would otherwise be passed on as "--tp 0".
        NUM_GPUS=1
        echo "Warning: nvidia-smi reported no GPUs, defaulting to 1 GPU" >&2
    fi
}

detect_num_gpus

# Announce the effective settings, then start the SGLang server with the
# python3 found on PATH. Tensor parallelism (--tp) is taken from NUM_GPUS.
printf '%s\n' \
    "Starting SGLang server with model: $MODEL_PATH" \
    "  Host: $SERVER_HOST" \
    "  Port: $SERVER_PORT"

launch_args=(
    --model-path "$MODEL_PATH"
    --host "$SERVER_HOST"
    --log-level warning
    --tp "$NUM_GPUS"
    --port "$SERVER_PORT"
)
python3 -m sglang.launch_server "${launch_args[@]}"