#!/usr/bin/env bash
set -euo pipefail

# --- Locations -------------------------------------------------------------
# All paths are relative, so they resolve against the directory the script is
# launched from.
LLF_ROOT="LLaMA-Factory"
RUNS_ROOT="LLaMA-Factory/runs"
SBATCH_FILE="Framework/infer/vllm_infer.sbatch"

# Create the output directories before anything is submitted.
# NOTE(review): infer_logs is presumably where the sbatch script writes its
# logs — confirm against the sbatch file.
mkdir -p infer_logs
mkdir -p "$RUNS_ROOT"

# --- Models ----------------------------------------------------------------
# One entry per model to evaluate; one job is submitted per entry. The
# template for each entry is derived from substrings of its path further down
# in this script.
declare -a MODEL_PATHS=(
  # YOUR MODEL PATH
)

# --- Dataset ---------------------------------------------------------------
# The same dataset name and raw file are used for every submitted job.
RAW_FILE="all_test_gskel.json"
PREDICT_DATASET="test_gskel"

#######################################
# Map a model path to the template name used for its inference job.
# Matching is by case-sensitive substring, checked in this order:
#   *InternVL3*               -> intern_vl
#   *Qwen3-VL* and *Instruct* -> qwen3_vl_nothink
#   *Qwen3-VL* and *Thinking* -> qwen3_vl
# Arguments: $1 - model path
# Outputs:   the template name on stdout (nothing when unrecognised)
# Returns:   0 if a template is known for the path, 1 otherwise
#######################################
template_for_model_path() {
  local model_path=$1
  case "$model_path" in
    *InternVL3*)
      printf '%s\n' "intern_vl"
      ;;
    *Qwen3-VL*)
      if [[ "$model_path" == *"Instruct"* ]]; then
        printf '%s\n' "qwen3_vl_nothink"
      elif [[ "$model_path" == *"Thinking"* ]]; then
        printf '%s\n' "qwen3_vl"
      else
        return 1
      fi
      ;;
    *)
      return 1
      ;;
  esac
}

# TEMPLATES[i] is the template for MODEL_PATHS[i]; unrecognised paths get the
# placeholder "unknown" so the two arrays stay index-aligned.
TEMPLATES=()
# The ${arr[@]+...} form keeps an empty MODEL_PATHS from tripping `set -u` on
# Bash versions older than 4.4.
for path in ${MODEL_PATHS[@]+"${MODEL_PATHS[@]}"}; do
  if template="$(template_for_model_path "$path")"; then
    TEMPLATES+=("$template")
  else
    # Diagnostics go to stderr so they are not mixed into regular output.
    echo "Warning: Unknown model path pattern: $path" >&2
    TEMPLATES+=("unknown")
  fi
done

# Inference settings shared by every job; the submission loop below forwards
# them to the sbatch job as environment variables.
DATASET_DIR="graphAGI"
# NOTE(review): if MAX_MODEL_LEN bounds prompt + generated tokens (as vLLM's
# max_model_len does), only 32768 - 30000 = 2768 tokens remain for the prompt
# — confirm this is intended.
MAX_NEW_TOKENS=30000
NUM_BEAMS=1
DO_SAMPLE=false
TENSOR_PARALLEL_SIZE=1
MAX_MODEL_LEN=32768

# Number of jobs to submit (one per model path).
N="${#MODEL_PATHS[@]}"

# Refuse to "succeed" with nothing to do: an empty MODEL_PATHS almost always
# means the placeholder list was never filled in.
if (( N == 0 )); then
  echo "❌ MODEL_PATHS is empty; add at least one model path" >&2
  exit 2
fi

# MODEL_PATHS and TEMPLATES are indexed in parallel, so their lengths must agree.
[[ "$N" -eq "${#TEMPLATES[@]}" ]] || { echo "❌ MODEL_PATHS != TEMPLATES" >&2; exit 2; }

echo "[INFO] Total jobs: ${N}"
echo "[INFO] SBATCH_FILE=${SBATCH_FILE}"

#######################################
# Submit one inference job to Slurm for a single model.
#
# The job's parameters are handed over as per-command environment assignments
# combined with a plain `--export=ALL`, instead of being spliced into a
# comma-separated `--export=ALL,VAR=val,...` list. This keeps values that
# contain commas intact and ensures the values chosen here are the ones the
# job sees even if a same-named variable is already exported in the caller's
# environment.
#
# Globals (read): SBATCH_FILE, LLF_ROOT, RUNS_ROOT, PREDICT_DATASET, RAW_FILE,
#                 DATASET_DIR, MAX_NEW_TOKENS, NUM_BEAMS, DO_SAMPLE,
#                 TENSOR_PARALLEL_SIZE, MAX_MODEL_LEN
# Arguments: $1 - model path (its basename becomes part of the job name)
#            $2 - template name for that model
# Outputs:   a summary banner, followed by whatever sbatch prints
# Returns:   the exit status of sbatch
#######################################
submit_inference_job() {
  local model_path=$1
  local template=$2
  local job_name
  job_name="infer_$(basename -- "$model_path")"

  echo "--------------------------------------------------"
  echo "[INFO] Submitting $job_name"
  echo "  MODEL_PATH      = ${model_path}"
  echo "  PREDICT_DATASET = ${PREDICT_DATASET}"
  echo "  TEMPLATE        = ${template}"
  echo "  RAW_FILE        = ${RAW_FILE}"
  echo "--------------------------------------------------"

  # Assignments placed before a command apply to that command's environment
  # only; they do not persist in this shell.
  LLF_ROOT="${LLF_ROOT}" \
    MODEL_PATH="${model_path}" \
    PREDICT_DATASET="${PREDICT_DATASET}" \
    RUNS_ROOT="${RUNS_ROOT}" \
    TEMPLATE="${template}" \
    RAW_FILE="${RAW_FILE}" \
    DATASET_DIR="${DATASET_DIR}" \
    MAX_NEW_TOKENS="${MAX_NEW_TOKENS}" \
    NUM_BEAMS="${NUM_BEAMS}" \
    DO_SAMPLE="${DO_SAMPLE}" \
    TENSOR_PARALLEL_SIZE="${TENSOR_PARALLEL_SIZE}" \
    MAX_MODEL_LEN="${MAX_MODEL_LEN}" \
    sbatch -J "${job_name}" --export=ALL "${SBATCH_FILE}"
}

# Submit one job per model, in list order. Stop at the first failed
# submission, naming the model and preserving sbatch's exit status.
for ((i = 0; i < N; i++)); do
  submit_inference_job "${MODEL_PATHS[$i]}" "${TEMPLATES[$i]}" || {
    rc=$?
    echo "❌ sbatch failed (exit ${rc}) for ${MODEL_PATHS[$i]}" >&2
    exit "$rc"
  }

  # Short pause between consecutive submissions.
  sleep 2
done

echo "🎉 All ${N} jobs submitted."
