#!/bin/bash
#
# bon0_gen.sh - split a parquet dataset across all idle GPUs and run
# gen_vllm.py on each slice in parallel (one background worker per GPU).
#
# Usage: bon0_gen.sh INPUT_PARQUET_PATH MODEL_PATH NUM_OUTPUTS
#   INPUT_PARQUET_PATH  parquet file to generate from (passed as --data)
#   MODEL_PATH          model passed to gen_vllm.py --model
#   NUM_OUTPUTS         value for gen_vllm.py --num_outputs
#                       (presumably responses per prompt - confirm in gen_vllm.py)
set -euo pipefail

# Input arguments.
if (( $# != 3 )); then
  echo "Usage: ${0##*/} INPUT_PARQUET_PATH MODEL_PATH NUM_OUTPUTS" >&2
  exit 2
fi
INPUT_PARQUET_PATH=$1
MODEL_PATH=$2
NUM_OUTPUTS=$3

if ! [[ "$NUM_OUTPUTS" =~ ^[1-9][0-9]*$ ]]; then
  echo "NUM_OUTPUTS must be a positive integer, got '${NUM_OUTPUTS}'" >&2
  exit 2
fi

# 1. Read the row count from the parquet file metadata. The path is handed to
#    Python via sys.argv instead of being pasted into the Python source, so a
#    path containing quotes cannot break (or inject code into) the snippet.
#    Failing loudly here matters: an empty NUM_ROWS would silently evaluate to
#    0 in the later arithmetic and launch empty jobs.
if ! NUM_ROWS=$(python3 -c 'import sys, pyarrow.parquet as pq; print(pq.ParquetFile(sys.argv[1]).metadata.num_rows)' "$INPUT_PARQUET_PATH"); then
  echo "Failed to read row count from ${INPUT_PARQUET_PATH}" >&2
  exit 1
fi
echo "Total rows in parquet: ${NUM_ROWS}"

# 2. Collect the indices of idle GPUs, i.e. GPUs whose used memory is below
#    FREE_GPU_MEM_MIB MiB (default 1000; override via the environment).
#    With --format=csv,noheader,nounits nvidia-smi prints one
#    "index, memory.used" line per GPU, e.g. "0, 15". Split on the comma:
#    awk's default whitespace split would yield "0," as the index.
#    Non-numeric readings (nvidia-smi may report e.g. "[N/A]") are never
#    treated as idle.
FREE_GPU_MEM_MIB=${FREE_GPU_MEM_MIB:-1000}
if ! [[ "$FREE_GPU_MEM_MIB" =~ ^[0-9]+$ ]]; then
  echo "FREE_GPU_MEM_MIB must be a non-negative integer, got '${FREE_GPU_MEM_MIB}'" >&2
  exit 2
fi
if ! command -v nvidia-smi >/dev/null 2>&1; then
  echo "nvidia-smi not found; cannot detect GPUs" >&2
  exit 1
fi
if ! VALID_GPUS=$(nvidia-smi --query-gpu=index,memory.used --format=csv,noheader,nounits \
    | awk -F', *' -v max_mib="$FREE_GPU_MEM_MIB" \
        '$2 ~ /^ *[0-9]+ *$/ && $2 + 0 < max_mib + 0 { print $1 }'); then
  echo "nvidia-smi query failed" >&2
  exit 1
fi
echo "Free GPUs: ${VALID_GPUS//$'\n'/ }"

# Abort when no GPU is idle.
if [[ -z "$VALID_GPUS" ]]; then
  echo "No free GPU available!" >&2
  exit 1
fi

# 3. Decide how many rows each GPU handles: every GPU gets SLICE_SIZE rows,
#    and the first REMAINDER GPUs take one extra row each.
NUM_GPUS=$(printf '%s\n' "$VALID_GPUS" | wc -w)
SLICE_SIZE=$(( NUM_ROWS / NUM_GPUS ))
REMAINDER=$(( NUM_ROWS - SLICE_SIZE * NUM_GPUS ))

printf 'Using %s GPUs, slice size = %s, remainder = %s\n' "$NUM_GPUS" "$SLICE_SIZE" "$REMAINDER"

# 4. Launch one gen_vllm.py worker per GPU on a contiguous row range
#    [BEGIN, END), then wait for every worker and report failures.
#    Each worker writes <input path without .parquet>_<BEGIN>_<END>.json.
#    VERL_HOME can be overridden from the environment; it defaults to
#    ~/verl_250713, the previously hard-coded location.
VERL_HOME=${VERL_HOME:-$HOME/verl_250713}
OUTPUT_PREFIX="${INPUT_PARQUET_PATH%.parquet}"

PIDS=()    # PIDs of the launched workers
RANGES=()  # human-readable description of each worker, parallel to PIDS

# Kill any workers launched so far. Needed because a non-interactive shell
# starts background jobs with SIGINT ignored, so Ctrl-C would otherwise stop
# only this script and leave the workers running on the GPUs.
stop_workers() {
  if (( ${#PIDS[@]} > 0 )); then
    kill -TERM "${PIDS[@]}" 2>/dev/null || true
  fi
}
trap 'echo "Interrupted; stopping workers..." >&2; stop_workers; exit 130' INT
trap 'echo "Terminated; stopping workers..." >&2; stop_workers; exit 143' TERM

BEGIN=0
# $VALID_GPUS is deliberately unquoted: each whitespace-separated word is one
# GPU index.
for GPU in $VALID_GPUS; do
  END=$(( BEGIN + SLICE_SIZE ))

  # Spread the remainder over the first GPUs, one extra row each.
  if (( REMAINDER > 0 )); then
    END=$(( END + 1 ))
    REMAINDER=$(( REMAINDER - 1 ))
  fi

  # With more GPUs than rows the trailing ranges are empty; don't occupy a
  # GPU loading the model for zero rows.
  if (( END == BEGIN )); then
    echo "Skipping GPU $GPU: no rows left to assign"
    continue
  fi

  OUTPUT_PATH="${OUTPUT_PREFIX}_${BEGIN}_${END}.json"
  echo "Launching on GPU $GPU with range [$BEGIN, $END)"

  CUDA_VISIBLE_DEVICES="$GPU" \
    "$VERL_HOME/.conda/bin/python" "$VERL_HOME/scripts/gen_vllm.py" \
    --data "$INPUT_PARQUET_PATH" \
    --begin "$BEGIN" --end "$END" \
    --dataset_split train \
    --output_finish_reason \
    --num_outputs "$NUM_OUTPUTS" \
    --prompt_key prompt \
    --output_key responses \
    --model "$MODEL_PATH" \
    --max_length 2560 \
    --temperature 0.5 \
    --gpu_memory_utilization 0.97 \
    --output_json "$OUTPUT_PATH" &
  PIDS+=("$!")
  RANGES+=("GPU $GPU range [$BEGIN, $END)")

  BEGIN=$END
done

# Wait for each worker individually so a failing one is not silently ignored
# (a bare `wait` with no arguments always returns 0).
FAILED=0
for (( i = 0; i < ${#PIDS[@]}; i++ )); do
  if ! wait "${PIDS[i]}"; then
    echo "Job on ${RANGES[i]} failed" >&2
    FAILED=$(( FAILED + 1 ))
  fi
done
trap - INT TERM

if (( FAILED > 0 )); then
  echo "${FAILED} of ${#PIDS[@]} jobs failed!" >&2
  exit 1
fi
echo "All jobs finished!"

# Example (args: INPUT_PARQUET_PATH MODEL_PATH NUM_OUTPUTS):
# bash ~/verl_250713/scripts/bon0_gen.sh \
#   ~/datasets/PRIME-RL-Eurus-2-RL-Data/validation.parquet \
#   ~/LLaMA-Factory-250514/saves_shuyan/qwen3-8B-base/prime-sft 4
