#!/usr/bin/env bash
# grid_search_llada.sh
# Purpose: grid-search over several model_args combinations. For every run the
# command line and the logs are stored in a directory named after its
# model_args; a failing combination does not stop the remaining ones.

# Deliberately no `set -e`: a failing run must not abort the rest of the grid.
set -u
set -o pipefail

######################## User-configurable section ########################
export HF_ALLOW_CODE_EVAL=1
export HF_DATASETS_TRUST_REMOTE_CODE=true

task="gsm8k"
num_fewshot=5
model_path="/home/ANONYMIZED_USER/WINO/models/LLaDA-8B-Instruct"

# Launcher / runtime parameters
num_processes=4
gpu_ids=0,1,2,3
# Recommended: make sure only these GPUs are used
export CUDA_VISIBLE_DEVICES="${gpu_ids}"

# Other boolean/display flags. Keep them lowercase: only the exact string
# "true" is written as JSON true into meta.json, anything else becomes false.
show_speed=true
ssd=true
verbose=false
kv_cache=true

# Root output directory (change as needed)
OUT_BASE="/home/ANONYMIZED_USER/WINO/dllm-scsd/llada_grid_runs"
# Without the output root nothing can be recorded, so stop before launching
# any job instead of running the whole grid with nowhere to store results.
if ! mkdir -p "${OUT_BASE}"; then
  echo "!!! 无法创建输出目录 ${OUT_BASE}，终止。" >&2
  exit 1
fi

# Grid
GEN_LENGTHS=(256)
DRAFT_LENGTHS=(4 6)
BLOCK_LENGTHS=(4 8 16)
REFRESH_INTERVALS=(50 100)
######################## User-configurable section ########################

# Helper: turn "a=1,b=2" into "a-1__b-2" so it can be used as a directory name.
# Arguments: $1 - comma-separated key=value string
# Outputs:   the converted string followed by a newline, on stdout
sanitize_tag () {
  local tag="$1"
  tag="${tag//,/__}"
  tag="${tag//=/-}"
  # printf instead of echo: echo would swallow values such as "-n" or "-e"
  # by treating them as its own options.
  printf '%s\n' "${tag}"
}

# Map a shell-style boolean string to a JSON boolean literal.
# Arguments: $1 - value to convert
# Outputs:   "true" when $1 is exactly the string "true", otherwise "false"
json_bool () {
  case "$1" in
    'true') printf '%s\n' true ;;
    *)      printf '%s\n' false ;;
  esac
}

# Escape a value so it can be placed between double quotes in a JSON document.
# Arguments: $1 - raw string
# Outputs:   the escaped string (no trailing newline) on stdout
_json_escape_string () {
  local s="$1"
  # Backslash first, otherwise the backslashes added below would be doubled.
  s=${s//\\/\\\\}
  s=${s//\"/\\\"}
  s=${s//$'\n'/\\n}
  s=${s//$'\r'/\\r}
  s=${s//$'\t'/\\t}
  printf '%s' "$s"
}

# Record the metadata of one run as JSON (handy for later aggregation).
# Globals (read): task, num_fewshot, model_path, show_speed, ssd, verbose,
#                 kv_cache, num_processes, gpu_ids
# Arguments: $1 - output file (overwritten)
#            $2 - gen_length (also written as "steps")
#            $3 - draft_length
#            $4 - block_length
#            $5 - refresh_interval
# Numeric fields are written verbatim and unquoted; string fields are
# JSON-escaped so quotes or backslashes in a path cannot break the document.
# Boolean fields go through json_bool.
write_meta_json () {
  local file="$1"
  local gen="$2" draft="$3" block="$4" refresh="$5"
  cat > "$file" <<EOF
{
  "task": "$(_json_escape_string "${task}")",
  "num_fewshot": ${num_fewshot},
  "model_path": "$(_json_escape_string "${model_path}")",
  "gen_length": ${gen},
  "draft_length": ${draft},
  "steps": ${gen},
  "block_length": ${block},
  "refresh_interval": ${refresh},
  "show_speed": $(json_bool "${show_speed}"),
  "ssd": $(json_bool "${ssd}"),
  "verbose": $(json_bool "${verbose}"),
  "kv_cache": $(json_bool "${kv_cache}"),
  "num_processes": ${num_processes},
  "gpu_ids": "$(_json_escape_string "${gpu_ids}")"
}
EOF
}

# Main loop: run every combination of the grid sequentially. For each one,
# save the command line, a meta.json, stdout/stderr logs and the exit code
# under ${OUT_BASE}/${task}/<tag>. A failing combination is reported and the
# loop moves on to the next one.
failed_runs=0
for gen in "${GEN_LENGTHS[@]}"; do
  for draft in "${DRAFT_LENGTHS[@]}"; do
    for block in "${BLOCK_LENGTHS[@]}"; do
      for refresh in "${REFRESH_INTERVALS[@]}"; do

        # The directory name is built from the model_args (without model_path)
        # so the configuration can be read at a glance.
        # Note: steps follows gen_length.
        tag="gen_length=${gen},draft_length=${draft},steps=${gen},block_length=${block},refresh_interval=${refresh},show_speed=${show_speed},ssd=${ssd},verbose=${verbose},kv_cache=${kv_cache}"
        model_args="model_path=${model_path},${tag}"
        safe_tag="$(sanitize_tag "$tag")"
        run_dir="${OUT_BASE}/${task}/${safe_tag}"
        mkdir -p "${run_dir}"

        # Single-line command string; it is saved verbatim and then executed
        # through `bash -lc` below.
        cmd="conda run -n llada accelerate launch"
        cmd+=" --num_processes ${num_processes} --gpu_ids ${gpu_ids}"
        cmd+=" /home/ANONYMIZED_USER/WINO/dllm-scsd/eval_model/eval_llada.py --tasks ${task} --num_fewshot ${num_fewshot}"
        cmd+=" --confirm_run_unsafe_code --model llada_dist"
        cmd+=" --model_args ${model_args}"

        # Save the command and the metadata
        printf '%s\n' "${cmd}" > "${run_dir}/command.txt"
        write_meta_json "${run_dir}/meta.json" "${gen}" "${draft}" "${block}" "${refresh}"

        echo ">>> 开始运行：${safe_tag}"
        # Run the command, showing stdout/stderr live while also logging each
        # to its own file. A real pipeline is used instead of process
        # substitution because bash does not wait for `>(...)` processes: the
        # log files could still be incomplete when they are read below.
        # fd 3 carries the command's stdout around the inner (stderr) pipe.
        # bash runs as a login shell (-l), presumably so that conda is set up
        # by the profile scripts - confirm before changing.
        {
          bash -lc "${cmd}" 2>&1 1>&3 3>&- \
            | tee "${run_dir}/stderr.log" >&2 3>&-
          # Propagate the command's status, not tee's.
          exit "${PIPESTATUS[0]}"
        } 3>&1 | tee "${run_dir}/stdout.log"
        exit_code=${PIPESTATUS[0]}

        echo "${exit_code}" > "${run_dir}/EXIT_CODE"

        if [[ ${exit_code} -ne 0 ]]; then
          failed_runs=$((failed_runs + 1))
          echo "!!! 组合 ${safe_tag} 运行失败（退出码 ${exit_code}），错误已写入 ${run_dir}/stderr.log，继续下一个组合…"
          # Best effort: a missing stderr.log must not stop the grid.
          tail -n 200 "${run_dir}/stderr.log" > "${run_dir}/error_summary_tail200.txt" || true
        else
          echo ">>> 组合 ${safe_tag} 运行完成"
        fi

      done
    done
  done
done

# Make failures visible in the final summary; the exit status stays unchanged.
if (( failed_runs > 0 )); then
  echo "!!! 共有 ${failed_runs} 个组合运行失败，详见各运行目录下的 EXIT_CODE 与 stderr.log" >&2
fi
echo "✅ 网格搜索已全部完成。结果位于：${OUT_BASE}/${task}"