# pip install transformers==4.49.0 lm_eval==0.4.8 accelerate==0.34.2
# pip install antlr4-python3-runtime==4.11 math_verify sympy hf_xet


# Both variables must be exported before any of the evaluation commands in
# this file are run, so that the launched processes inherit them.
export HF_ALLOW_CODE_EVAL=1 \
  HF_DATASETS_TRUST_REMOTE_CODE=true


# # conditional likelihood estimation benchmarks
# accelerate launch eval_llada.py --tasks gpqa_main_n_shot --num_fewshot 5 --model llada_dist --batch_size 8 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=0.5,is_check_greedy=False,mc_num=128

# accelerate launch eval_llada.py --tasks truthfulqa_mc2 --num_fewshot 0 --model llada_dist --batch_size 8 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=2.0,is_check_greedy=False,mc_num=128

# accelerate launch eval_llada.py --tasks arc_challenge --num_fewshot 0 --model llada_dist --batch_size 8 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=0.5,is_check_greedy=False,mc_num=128

# accelerate launch eval_llada.py --tasks hellaswag --num_fewshot 0 --model llada_dist --batch_size 8 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=0.5,is_check_greedy=False,mc_num=128

# accelerate launch eval_llada.py --tasks winogrande --num_fewshot 5 --model llada_dist --batch_size 8 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=0.0,is_check_greedy=False,mc_num=128

# accelerate launch eval_llada.py --tasks piqa --num_fewshot 0 --model llada_dist --batch_size 8 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=0.5,is_check_greedy=False,mc_num=128

# accelerate launch eval_llada.py --tasks mmlu --num_fewshot 5 --model llada_dist --batch_size 1 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=0.0,is_check_greedy=False,mc_num=1

# accelerate launch eval_llada.py --tasks cmmlu --num_fewshot 5 --model llada_dist --batch_size 1 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=0.0,is_check_greedy=False,mc_num=1

# accelerate launch eval_llada.py --tasks ceval-valid --num_fewshot 5 --model llada_dist --batch_size 1 --model_args model_path='GSAI-ML/LLaDA-8B-Base',cfg=0.0,is_check_greedy=False,mc_num=1


# conditional generation benchmarks

# export MASKING=$1  # e.g., binary_search
# export MODEL_PATH='GSAI-ML/LLaDA-8B-Instruct'
# export OUT=results/gsm8k_${MASKING}.json

# CUDA_VISIBLE_DEVICES=0 accelerate launch eval_llada.py --tasks gsm8k --model llada_dist --model_args "model_path='${MODEL_PATH}',gen_length=1024,steps=1024,block_length=8,remasking='${MASKING}'" --output_path $OUT

# CUDA_VISIBLE_DEVICES=1 accelerate launch eval_llada.py --tasks minerva_math --model llada_dist --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=1024,steps=1024,block_length=8,remasking='binary_search'

# CUDA_VISIBLE_DEVICES=2 accelerate launch eval_llada.py --tasks humaneval --model llada_dist --confirm_run_unsafe_code --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=1024,steps=1024,block_length=8,remasking='binary_search'

# CUDA_VISIBLE_DEVICES=3 accelerate launch eval_llada.py --tasks mbpp --model llada_dist --confirm_run_unsafe_code --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=1024,steps=1024,block_length=8,remasking='binary_search'

# CUDA_VISIBLE_DEVICES=4 accelerate launch eval_llada.py --tasks bbh --model llada_dist --model_args model_path='GSAI-ML/LLaDA-8B-Instruct',gen_length=1024,steps=1024,block_length=1024,remasking='binary_search'


# ===== Evaluation sweep =====
# Runs every task in TASKS for each (remasking strategy, block length)
# combination. Tasks inside one combination are launched in the background,
# one per entry of CUDA_DEVICES (round-robin), and the script blocks whenever
# as many jobs as there are devices have been started.
# NOTE: this section needs bash (arrays, [[ ]], (( ))); run it with `bash`.

# ===== Configurable parameters =====
MODEL_PATH=GSAI-ML/LLaDA-8B-Instruct
STEPS=1024
GEN_LENGTH=1024

# ===== Pre-create the base results directory =====
MODEL_NAME=$(basename "$MODEL_PATH")
mkdir -p "results/${MODEL_NAME}"

# ===== Sweep definition =====
TASKS=("minerva_math" "humaneval" "hellaswag")
BLOCK_LENGTHs=(8)
REMAKING_STRATEGIES=("binary_search" "low_confidence" "bst" "auto_regression")
CUDA_DEVICES=(0 1 2)

# Bookkeeping for background jobs: PIDs and human-readable labels of the jobs
# in flight, plus a running count of jobs that exited non-zero.
FAILED_RUNS=0
BATCH_PIDS=()
BATCH_LABELS=()

#######################################
# Waits for every job recorded in BATCH_PIDS and reports the ones that failed.
# Globals:   BATCH_PIDS, BATCH_LABELS (read, then cleared), FAILED_RUNS (written)
# Arguments: none
# Outputs:   one line on stderr per failed job
# Returns:   0 always; failures are accumulated in FAILED_RUNS
#######################################
wait_for_batch() {
  local i
  for i in "${!BATCH_PIDS[@]}"; do
    # `wait <pid>` yields that job's exit status, unlike a bare `wait`.
    if ! wait "${BATCH_PIDS[$i]}"; then
      echo "!!! FAILED: ${BATCH_LABELS[$i]}" >&2
      FAILED_RUNS=$((FAILED_RUNS + 1))
    fi
  done
  BATCH_PIDS=()
  BATCH_LABELS=()
  return 0
}

# ===== Evaluation loop =====
for REMASKING in "${REMAKING_STRATEGIES[@]}"; do
  for BLOCK_LENGTH in "${BLOCK_LENGTHs[@]}"; do
    echo "=== Starting batch for BLOCK_LENGTH=$BLOCK_LENGTH, REMASKING=$REMASKING ==="
    device_idx=0  # restart the count per batch so N tasks map onto N devices

    for TASK in "${TASKS[@]}"; do
      CUDA_DEVICE=${CUDA_DEVICES[$((device_idx % ${#CUDA_DEVICES[@]}))]}
      # Plain assignment: `((device_idx++))` returns status 1 when the old
      # value is 0, which would abort the script under `set -e`.
      device_idx=$((device_idx + 1))
      OUTPUT_PATH="results/${MODEL_NAME}_${TASK}/${REMASKING}_block${BLOCK_LENGTH}"
      # Create the directory this run actually writes under; the base
      # directory created above is a different path.
      mkdir -p "${OUTPUT_PATH%/*}"

      # Only humaneval and mbpp are launched with --confirm_run_unsafe_code.
      # An array keeps the optional flag out of unquoted word splitting.
      EXTRA_ARGS=()
      if [[ "$TASK" == "humaneval" || "$TASK" == "mbpp" ]]; then
        EXTRA_ARGS+=("--confirm_run_unsafe_code")
      fi
      MODEL_ARGS="model_path=${MODEL_PATH},gen_length=${GEN_LENGTH},steps=${STEPS},block_length=${BLOCK_LENGTH},remasking=${REMASKING}"
      echo "===> Running task: $TASK, block_length: $BLOCK_LENGTH, remasking: $REMASKING, device: $CUDA_DEVICE"

      CUDA_VISIBLE_DEVICES="$CUDA_DEVICE" accelerate launch eval_llada.py \
        --tasks "$TASK" \
        --model llada_dist \
        --num_fewshot 0 \
        "${EXTRA_ARGS[@]}" \
        --model_args "$MODEL_ARGS" \
        --output_path "$OUTPUT_PATH" &
      BATCH_PIDS+=("$!")
      BATCH_LABELS+=("task=$TASK block_length=$BLOCK_LENGTH remasking=$REMASKING device=$CUDA_DEVICE")

      # Concurrency control: once one job per device has been started, block
      # until all of them are done before launching more.
      if (( device_idx % ${#CUDA_DEVICES[@]} == 0 )); then
        wait_for_batch
        echo "=== Batch finished for BLOCK_LENGTH=$BLOCK_LENGTH, REMASKING=$REMASKING ==="
      fi
    done

    # Collect any jobs left over when the task count is not a multiple of the
    # device count (no-op otherwise).
    wait_for_batch
    echo "=== Finished all tasks for BLOCK_LENGTH=$BLOCK_LENGTH, REMASKING=$REMASKING ==="
  done
done

if (( FAILED_RUNS > 0 )); then
  echo "=== ${FAILED_RUNS} evaluation run(s) failed ===" >&2
fi
# Deliberately no `exit`: the status of this last command becomes the script's
# exit status (non-zero if any run failed) without terminating a sourcing shell.
(( FAILED_RUNS == 0 ))
