#!/bin/bash
# Deploy a vLLM OpenAI-compatible server for guard models (e.g., Llama-Guard).
# This script mirrors deploy_vllm.sh but uses guard-appropriate defaults.

# Equivalent to `set -a`: every variable assigned from here on is exported
# to child processes (including vllm).
set -o allexport

# ---- Defaults (each can be overridden with a command-line option) ----
MODEL_PATH="meta-llama/Llama-Guard-3-8B"
SERVED_MODEL_NAME="guard"
DTYPE="bfloat16"
SERVER_HOST="0.0.0.0"
# 30001 rather than 30000 so the guard server does not clash with the main
# rewrite/attacker server, which often listens on 30000.
SERVER_PORT="30001"

# vLLM can serve LoRA adapters (--enable-lora / --lora-modules). The option is
# kept for parity with deploy_vllm.sh; guard deployments typically leave it empty.
LORA_PATH=""

# Print command-line help to stdout. The "(default: ...)" values are expanded
# from the configuration variables at the moment usage is called.
usage() {
    cat <<EOF
Usage: $0 [OPTIONS]
  -m, --model MODEL_PATH          Model path/tag (default: $MODEL_PATH)
  -l, --lora-path PATH            (optional) LoRA path
  --served-model-name NAME        Served model name (default: $SERVED_MODEL_NAME)
  --port PORT                     Server port (default: $SERVER_PORT)
  --host HOST                     Server host (default: $SERVER_HOST)
  --dtype DTYPE                   vLLM dtype (default: $DTYPE)

Examples:
  bash $0 --port 30001
  bash $0 -m meta-llama/Llama-Guard-3-8B --served-model-name guard --port 30001
EOF
}

# Abort with a clear message when a value-taking option ($1) is the last
# argument ($2 is the number of arguments remaining, including the option).
# Without this check `shift 2` fails without shifting when only one
# positional parameter is left, and the parse loop would never terminate.
require_option_value() {
    if (( $2 < 2 )); then
        echo "Error: option $1 requires a value" >&2
        usage >&2
        exit 1
    fi
}

# Parse command-line arguments into the configuration globals (MODEL_PATH,
# LORA_PATH, SERVED_MODEL_NAME, SERVER_PORT, SERVER_HOST, DTYPE).
# Options may appear in any order; a repeated option's last value wins.
# -h/--help prints usage and exits 0; an unknown option or a missing value
# prints a message plus usage to stderr and exits 1.
parse_args() {
    while [[ $# -gt 0 ]]; do
        case $1 in
            -m|--model)
                require_option_value "$1" "$#"
                MODEL_PATH="$2"
                shift 2
                ;;
            -l|--lora-path|--lora)
                require_option_value "$1" "$#"
                LORA_PATH="$2"
                shift 2
                ;;
            --served-model-name)
                require_option_value "$1" "$#"
                SERVED_MODEL_NAME="$2"
                shift 2
                ;;
            --port)
                require_option_value "$1" "$#"
                SERVER_PORT="$2"
                shift 2
                ;;
            --host)
                require_option_value "$1" "$#"
                SERVER_HOST="$2"
                shift 2
                ;;
            --dtype)
                require_option_value "$1" "$#"
                DTYPE="$2"
                shift 2
                ;;
            -h|--help)
                usage
                exit 0
                ;;
            *)
                echo "Unknown option: $1" >&2
                usage >&2
                exit 1
                ;;
        esac
    done
}

parse_args "$@"

# The server is launched through the `vllm` CLI, so it must be on PATH
# (e.g. the environment it was installed into must be active).
# Diagnostics go to stderr so they are not mixed into captured stdout.
if ! command -v vllm &> /dev/null; then
    echo "Error: vllm is not available in PATH" >&2
    exit 1
fi

# Print how many GPUs the server will be able to use (0 if none are found).
#   - If CUDA_VISIBLE_DEVICES is set and non-empty, count its non-empty
#     comma-separated entries: only those devices are exposed to the process,
#     so sizing tensor parallelism from nvidia-smi (which lists every physical
#     GPU) would over-count.
#   - Otherwise count the "GPU <n>: ..." lines printed by
#     `nvidia-smi --list-gpus`. Matching the line prefix (instead of `wc -l`)
#     keeps driver error text and indented MIG lines from being counted.
count_visible_gpus() {
    local count=0 entry
    if [[ -n "${CUDA_VISIBLE_DEVICES:-}" ]]; then
        local -a entries=()
        IFS=',' read -r -a entries <<<"$CUDA_VISIBLE_DEVICES"
        for entry in "${entries[@]}"; do
            if [[ -n "${entry//[[:space:]]/}" ]]; then
                count=$((count + 1))
            fi
        done
    elif command -v nvidia-smi &> /dev/null; then
        # grep -c prints 0 and exits 1 on no match; keep the 0.
        count=$(nvidia-smi --list-gpus 2>/dev/null | grep -c '^GPU ') || count=0
    fi
    printf '%s\n' "$count"
}

# Tensor-parallel size = number of usable GPUs. A size of 0 is invalid, so fall
# back to 1 when detection finds nothing (no nvidia-smi, no driver, etc.).
NUM_GPUS=$(count_visible_gpus)
if [[ "$NUM_GPUS" =~ ^[0-9]+$ ]] && (( NUM_GPUS >= 1 )); then
    echo "Detected $NUM_GPUS GPU(s), setting tensor parallelism to $NUM_GPUS"
else
    NUM_GPUS=1
    echo "No GPUs detected (nvidia-smi unavailable or reported none), defaulting to 1 GPU"
fi

# ---- LoRA adapter resolution ----
# A LoRA path may name a specific checkpoint directory, or a training output
# directory containing checkpoint-* subdirectories. For the latter, the
# checkpoint to serve is chosen in this order of preference:
#   1. checkpoint-<global_step>, where global_step comes from
#      <dir>/trainer_state.json,
#   2. the most recently modified checkpoint-* subdirectory,
#   3. the directory itself.
# Paths that are not existing directories are passed through unchanged.

# Print the most recently modified checkpoint-* subdirectory of $1, if any.
latest_checkpoint_dir() {
    local parent=$1 candidate newest=""
    for candidate in "$parent"/checkpoint-*; do
        # Skips plain files and the literal pattern when nothing matches.
        [[ -d "$candidate" ]] || continue
        if [[ -z "$newest" || "$candidate" -nt "$newest" ]]; then
            newest=$candidate
        fi
    done
    if [[ -n "$newest" ]]; then
        printf '%s\n' "$newest"
    fi
}

# Print the "global_step" field of the trainer_state.json at $1, or nothing if
# it cannot be read or parsed. The path is passed via argv instead of being
# spliced into the Python source, so quotes or other special characters in
# the path cannot break (or inject code into) the program.
read_global_step() {
    python3 -c 'import json, sys; print(json.load(open(sys.argv[1])).get("global_step", ""))' "$1" 2>/dev/null || true
}

# Print the adapter path vLLM should load for the user-supplied path $1.
resolve_lora_path() {
    local path=$1 step latest
    if [[ ! -d "$path" || "$path" == */checkpoint-* ]]; then
        # Not a directory (assumed valid as given) or already a checkpoint.
        printf '%s\n' "$path"
        return 0
    fi
    if [[ -f "$path/trainer_state.json" ]]; then
        step=$(read_global_step "$path/trainer_state.json")
        if [[ "$step" =~ ^[0-9]+$ && -d "$path/checkpoint-$step" ]]; then
            printf '%s\n' "$path/checkpoint-$step"
            return 0
        fi
    fi
    latest=$(latest_checkpoint_dir "$path")
    printf '%s\n' "${latest:-$path}"
}

FINAL_LORA_PATH=""
if [[ -n "${LORA_PATH:-}" ]]; then
    FINAL_LORA_PATH=$(resolve_lora_path "$LORA_PATH")
fi

# When a LoRA adapter is served, vLLM matches a request's `model` field against
# the base model's served name(s) before any LoRA adapter names, so an adapter
# registered under the same name as the base model would never be selected.
# Serve the base model under a "-base" suffix instead, so that clients
# requesting model="$SERVED_MODEL_NAME" get the adapter as intended.
BASE_SERVED_MODEL_NAME="$SERVED_MODEL_NAME"
if [ -n "$FINAL_LORA_PATH" ]; then
    BASE_SERVED_MODEL_NAME="${SERVED_MODEL_NAME}-base"
fi

# Run vLLM server
echo "Starting vLLM guard server with model: $MODEL_PATH"
echo "  Host: $SERVER_HOST"
echo "  Port: $SERVER_PORT"
echo "  Served model name: $SERVED_MODEL_NAME"
if [ -n "$FINAL_LORA_PATH" ]; then
    echo "  LoRA: $FINAL_LORA_PATH (registered as model '$SERVED_MODEL_NAME')"
    echo "  Base model served as: $BASE_SERVED_MODEL_NAME"
fi
echo "  TP: $NUM_GPUS"
echo "  Dtype: $DTYPE"

# NOTE(review): --trust-remote-code lets vLLM execute Python code shipped with
# the model repository; only point --model at trusted sources.
VLLM_ARGS=(
  serve "$MODEL_PATH"
  --host "$SERVER_HOST"
  --port "$SERVER_PORT"
  --served-model-name "$BASE_SERVED_MODEL_NAME"
  --trust-remote-code
  --dtype "$DTYPE"
  --tensor-parallel-size "$NUM_GPUS"
  --seed "123"
  --disable-uvicorn-access-log
)

# LoRA: register the adapter under the requested served model name.
if [ -n "$FINAL_LORA_PATH" ]; then
    VLLM_ARGS+=(--enable-lora --lora-modules "${SERVED_MODEL_NAME}=${FINAL_LORA_PATH}")
fi

# exec replaces this shell with vllm so signals (SIGTERM/SIGINT from a
# supervisor or terminal) reach the server directly; its exit status becomes
# the script's exit status, as before.
exec vllm "${VLLM_ARGS[@]}"



