#!/usr/bin/env bash
# Sweep eval.py over a grid of decoding settings (block length, draft length,
# KV cache, SSD, refresh interval) for several benchmark tasks, one eval.py run
# per combination.
#
# Every run's stdout/stderr is written to a timestamped log under ${log_dir}.
# humaneval/mbpp additionally get a per-configuration --output_path directory
# under ${results_root} (with --log_samples).
#
# eval.py is resolved relative to the current working directory, so run this
# from the directory that contains it.
set -euo pipefail

export HF_ALLOW_CODE_EVAL=1
export HF_DATASETS_TRUST_REMOTE_CODE=true

# All logs and results live under one root. Previously some runs wrote to a
# different home directory or a relative path whose directories were never
# created, which made the output redirect fail.
readonly root_dir='/home/ANONYMIZED_USER/dllms'
readonly variant='base'
readonly log_dir="${root_dir}/logs/${variant}"
readonly results_root="${root_dir}/results"
mkdir -p -- "$log_dir" "$results_root"

readonly length=256
readonly model_path='/home/ANONYMIZED_USER/WINO/models/LLaDA-1.5'
readonly device='cuda:1'

# Mapping table: task name -> number of few-shot examples passed to eval.py.
declare -A FEWSHOT_MAP=(
  [humaneval]=0
  [mbpp]=3
  [gsm8k]=5
  [minerva_math]=4
)

# Sweep grid; every combination of these values is run for every task.
block_lengths=(8)
draft_lengths=(4)
caches=(True)
ssds=(True False)
refresh_intervals=(8)
tasks=(gsm8k minerva_math humaneval mbpp)

timestamp() { date +"%Y%m%d-%H%M%S"; }

# Number of runs that failed (or tasks that were skipped). A single failed run
# does not stop the sweep; the script exits non-zero at the end instead.
failures=0

# "${tasks[@]}" iterates every task; a bare "${tasks}" would yield only the
# first element.
for task in "${tasks[@]}"; do
  # With `set -u`, reading a missing key would abort the whole script, so
  # check explicitly and skip tasks that have no few-shot count.
  if [[ -z "${FEWSHOT_MAP[$task]+set}" ]]; then
    echo "[ERROR] no few-shot count defined for task '${task}'; skipping" >&2
    failures=$((failures + 1))
    continue
  fi
  num_fewshot="${FEWSHOT_MAP[$task]}"

  for block_length in "${block_lengths[@]}"; do
    for draft_length in "${draft_lengths[@]}"; do
      for cache in "${caches[@]}"; do
        for ssd in "${ssds[@]}"; do
          for refresh_interval in "${refresh_intervals[@]}"; do

            echo ">>> Running task=$task fewshot=$num_fewshot block=$block_length draft=$draft_length cache=$cache ssd=$ssd refresh=$refresh_interval"

            # Identifies this grid point in log/result paths.
            run_tag="block_${block_length}_draft_${draft_length}_cache_${cache}_ssd_${ssd}_refresh_${refresh_interval}"
            log_file="${log_dir}/${task}_${run_tag}_num_fewshot_${num_fewshot}_$(timestamp).log"

            # Common arguments, kept as an array so each argument is passed
            # through intact.
            model_args="model_path=${model_path},device=${device},gen_length=${length},draft_length=${draft_length},block_length=${block_length},kv_cache=${cache},ssd=${ssd},verbose=False,refresh_interval=${refresh_interval},show_speed=True"
            common_args=(
              --tasks "$task" --num_fewshot "$num_fewshot"
              --confirm_run_unsafe_code --model dream
              --model_args "$model_args"
            )
            cmd=(python eval.py "${common_args[@]}")

            # Per-task customization.
            case "$task" in
              humaneval|mbpp)
                out_dir="${results_root}/${task}/${variant}/${run_tag}"
                mkdir -p -- "$out_dir"
                cmd+=(--output_path "$out_dir" --log_samples)
                ;;
              gsm8k)
                ;;
              minerva_math)
                # Passes --limit 0.1 to eval.py for this task only.
                cmd+=(--limit 0.1)
                ;;
              *)
                echo "[WARN] 没有定义 $task 的专用参数，默认只用 common_args" >&2
                ;;
            esac

            # Report a failed run (its output is only in the log) and keep going.
            if ! "${cmd[@]}" > "$log_file" 2>&1; then
              echo "[ERROR] run failed: task=${task} ${run_tag}; see ${log_file}" >&2
              failures=$((failures + 1))
            fi

          done
        done
      done
    done
  done
done

if (( failures > 0 )); then
  echo "[ERROR] ${failures} run(s) failed or were skipped" >&2
  exit 1
fi
