#!/bin/bash
# Export every variable assigned from here on (same as `set -a`), so these
# settings are also present in the environment of the vllm process launched
# at the end of the script.
set -o allexport

# ---- Defaults; each one can be overridden by a command-line option ----
MODEL_PATH="${SCRATCH}/increase_rewriter/dpo"   # base model path or tag
SERVED_MODEL_NAME="increase"                     # name clients pass as `model`
SERVER_HOST="0.0.0.0"
SERVER_PORT="8000"
DTYPE="bfloat16"

# Optional LoRA adapter (empty = serve the base model only). When set, it is
# handed to vLLM through --enable-lora/--lora-modules under the served model
# name, with the intent that requests using model="<served name>" go to the
# adapter. NOTE(review): this relies on vLLM resolving a name shared by the
# base model and an adapter to the adapter — confirm for the installed version.
LORA_PATH=""

# Print the option summary to stdout. Defaults are expanded at call time, so
# they reflect any options that were already parsed.
usage() {
    echo "Usage: $0 [OPTIONS]"
    echo "  -m, --model MODEL_PATH          Model path/tag (default: $MODEL_PATH)"
    echo "  -l, --lora-path PATH            (optional) LoRA path"
    echo "  --served-model-name NAME        Served model name (default: $SERVED_MODEL_NAME)"
    echo "  --port PORT                     Server port (default: $SERVER_PORT)"
    echo "  --host HOST                     Server host (default: $SERVER_HOST)"
    echo "  --dtype DTYPE                   vLLM dtype (default: $DTYPE)"
    echo "  -h, --help                      Show this help and exit"
}

# Report a command-line error plus the usage text on stderr, then exit 1.
die_usage() {
    echo "$1" >&2
    usage >&2
    exit 1
}

# Parse command line arguments. Every option except -h/--help takes exactly
# one value. The argument count is checked BEFORE `shift 2`: if the value is
# missing, `shift 2` fails without shifting anything, and the loop would
# otherwise spin forever on the same trailing option.
while [[ $# -gt 0 ]]; do
    opt="$1"
    case "$opt" in
        -h|--help)
            usage
            exit 0
            ;;
        -m|--model|-l|--lora-path|--lora|--served-model-name|--port|--host|--dtype)
            if [[ $# -lt 2 ]]; then
                die_usage "Error: option $opt requires a value"
            fi
            val="$2"
            shift 2
            case "$opt" in
                -m|--model)            MODEL_PATH="$val" ;;
                -l|--lora-path|--lora) LORA_PATH="$val" ;;
                --served-model-name)   SERVED_MODEL_NAME="$val" ;;
                --port)                SERVER_PORT="$val" ;;
                --host)                SERVER_HOST="$val" ;;
                --dtype)               DTYPE="$val" ;;
            esac
            ;;
        *)
            die_usage "Unknown option: $opt"
            ;;
    esac
done

# Fail fast if the vllm CLI cannot be found on PATH; diagnostics go to stderr.
if ! command -v vllm &> /dev/null; then
    echo "Error: vllm is not available in PATH" >&2
    exit 1
fi

# Detect GPU count for tensor parallel.
# CUDA_VISIBLE_DEVICES takes precedence when set and non-empty: vLLM can only
# use the devices it exposes, whereas `nvidia-smi --list-gpus` lists every GPU
# on the node regardless of that mask. From nvidia-smi only lines starting with
# "GPU " are counted, so an error message (e.g. no driver loaded) or indented
# MIG device lines are not mistaken for extra GPUs.
NUM_GPUS=0
GPU_COUNT_SOURCE=""
if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
    GPU_COUNT_SOURCE="CUDA_VISIBLE_DEVICES"
    NUM_GPUS=$(tr ',' '\n' <<< "$CUDA_VISIBLE_DEVICES" | grep -c .)
elif command -v nvidia-smi &> /dev/null; then
    GPU_COUNT_SOURCE="nvidia-smi"
    NUM_GPUS=$(nvidia-smi --list-gpus 2>/dev/null | grep -c '^GPU ')
fi

# --tensor-parallel-size must be at least 1; fall back when nothing was found.
if [[ "$NUM_GPUS" =~ ^[0-9]+$ ]] && (( NUM_GPUS > 0 )); then
    echo "Detected $NUM_GPUS GPU(s) via $GPU_COUNT_SOURCE, setting tensor parallelism to $NUM_GPUS"
else
    NUM_GPUS=1
    echo "No GPUs detected, defaulting to 1 GPU"
fi

# Print the checkpoint-<N> subdirectory of $1 with the largest numeric N, or
# nothing if there is none. Selection is by step number rather than by
# modification time, which copying or syncing a run directory can scramble.
latest_lora_checkpoint() {
    local dir="$1" candidate step best="" best_step=-1
    for candidate in "$dir"/checkpoint-*; do
        # Skips plain files and the literal pattern left by an unmatched glob.
        [[ -d "$candidate" ]] || continue
        step="${candidate##*/checkpoint-}"
        [[ "$step" =~ ^[0-9]+$ ]] || continue
        if (( 10#$step > best_step )); then
            best_step=$(( 10#$step ))
            best="$candidate"
        fi
    done
    printf '%s' "$best"
}

# Resolve which LoRA adapter to serve (FINAL_LORA_PATH empty = no LoRA):
#   * a path in which some component starts with "checkpoint-" is used as-is;
#   * any other local directory (symlinks to directories included) resolves to
#     checkpoint-<global_step> when trainer_state.json names an existing one,
#     else to its checkpoint-<N> subdirectory with the largest N, else to the
#     directory itself;
#   * anything that is not a local directory is passed through unchanged and
#     left for vLLM to interpret.
FINAL_LORA_PATH=""
if [[ -n "$LORA_PATH" ]]; then
    FINAL_LORA_PATH="$LORA_PATH"
    if [[ -d "$LORA_PATH" && "$LORA_PATH" != checkpoint-* && "$LORA_PATH" != */checkpoint-* ]]; then
        GS=""
        TRAINER_STATE_JSON="$LORA_PATH/trainer_state.json"
        if [[ -f "$TRAINER_STATE_JSON" ]]; then
            # The path is passed via argv, never spliced into the Python source:
            # a quote in the path would otherwise break (or inject into) the code.
            # Any failure (no python3, bad JSON) just leaves GS empty.
            GS=$(python3 -c 'import json, sys; print(json.load(open(sys.argv[1])).get("global_step", ""))' "$TRAINER_STATE_JSON" 2>/dev/null)
        fi
        if [[ "$GS" =~ ^[0-9]+$ && -d "$LORA_PATH/checkpoint-$GS" ]]; then
            FINAL_LORA_PATH="$LORA_PATH/checkpoint-$GS"
        else
            LATEST_CHECKPOINT=$(latest_lora_checkpoint "$LORA_PATH")
            if [[ -n "$LATEST_CHECKPOINT" ]]; then
                FINAL_LORA_PATH="$LATEST_CHECKPOINT"
            fi
        fi
    fi
fi

# Run vLLM server
echo "Starting vLLM server with model: $MODEL_PATH"
echo "  Host: $SERVER_HOST"
echo "  Port: $SERVER_PORT"
echo "  Served model name: $SERVED_MODEL_NAME"
if [[ -n "$FINAL_LORA_PATH" ]]; then
    echo "  LoRA: $FINAL_LORA_PATH (registered as model '$SERVED_MODEL_NAME')"
fi
echo "  TP: $NUM_GPUS"
echo "  Dtype: $DTYPE"

VLLM_ARGS=(
  serve "$MODEL_PATH"
  --host "$SERVER_HOST"
  --port "$SERVER_PORT"
  --served-model-name "$SERVED_MODEL_NAME"
  --trust-remote-code
  --dtype "$DTYPE"
  --tensor-parallel-size "$NUM_GPUS"
  --seed "123"
  --disable-uvicorn-access-log
)

# LoRA: register the adapter under the served model name, so the base model
# and the adapter share that one name.
# NOTE(review): whether a request for the shared name is routed to the adapter
# or to the base model depends on the order in which the installed vLLM
# version checks base-model names vs. LoRA names — confirm it hits the adapter.
if [[ -n "$FINAL_LORA_PATH" ]]; then
    VLLM_ARGS+=(--enable-lora --lora-modules "${SERVED_MODEL_NAME}=${FINAL_LORA_PATH}")
fi

# Log the exact, shell-quoted command so a run can be reproduced by hand.
printf '  Command: vllm'
printf ' %q' "${VLLM_ARGS[@]}"
printf '\n'

# exec replaces this shell with the server process, so signals delivered to the
# script's PID (e.g. `kill`, a container runtime's stop) reach vLLM directly
# instead of terminating only the wrapper shell and orphaning the server.
exec vllm "${VLLM_ARGS[@]}"

