#!/usr/bin/env bash
set -euo pipefail

# Inputs handed to every deepspeed invocation below: the model given as
# --model_name_or_path, the --hostfile, and the --deepspeed JSON config.
MODEL="Qwen/Qwen3-4B-Base"
HOSTFILE="/horovod/generated/hostfile"
DSCONF="./ds_config_zero2.json"

# Every run writes to (and later runs read the student from) this one
# directory, so make sure it exists before the first launch.
OUT_DIR="./output/slimpajama/Qwen3-4B-Base/kl-all-mlp-0.2-600k"
mkdir -p -- "$OUT_DIR"

# Layer geometry used to turn a block number into layer indices, plus the
# base dimensions: a schedule entry equal to the base value is skipped when
# the per-run override lists are built.
NUM_LAYERS=36
BLOCK_SIZE=4
BASE_HEAD_DIM=128
BASE_MLP_DIM=9728

# Per-block target dimensions. Entry 0 belongs to the last BLOCK_SIZE layers
# of the model, entry 1 to the BLOCK_SIZE layers before those, and so on.
ATT_BLOCK_DIMS=(128 128 128 128 128)
MLP_BLOCK_DIMS=(2560 2816 3328 4608 9216)

# Number of runs: the longer of the two schedules, capped at the number of
# whole blocks that fit into the model.
MAX_BLOCKS=$(( NUM_LAYERS / BLOCK_SIZE ))
TOTAL_STEPS=${#MLP_BLOCK_DIMS[@]}
if (( ${#ATT_BLOCK_DIMS[@]} > TOTAL_STEPS )); then
  TOTAL_STEPS=${#ATT_BLOCK_DIMS[@]}
fi
if (( TOTAL_STEPS > MAX_BLOCKS )); then
  TOTAL_STEPS=$MAX_BLOCKS
fi

#######################################
# Print the integers start..end (inclusive) as a comma-separated list.
# The range is clamped to the valid layer indices 0..NUM_LAYERS-1; an empty
# range (start > end after clamping) prints an empty line.
# Globals:   NUM_LAYERS (read)
# Arguments: $1 - first index, $2 - last index
# Outputs:   "start,start+1,...,end" on stdout
#######################################
range_join() {
  # k is declared local so a direct call cannot clobber a caller's variable.
  local start=$1 end=$2 s="" k
  if (( start < 0 )); then
    start=0
  fi
  if (( end >= NUM_LAYERS )); then
    end=$(( NUM_LAYERS - 1 ))
  fi
  for (( k = start; k <= end; k++ )); do
    # ${s:+,} contributes the separator only once the list is non-empty.
    s+="${s:+,}${k}"
  done
  printf '%s\n' "$s"
}
# Print the single argument four times, comma-separated: v → "v,v,v,v".
repeat4() {
  printf '%s,%s,%s,%s\n' "$1" "$1" "$1" "$1"
}
# Print "first,last" layer indices of block j, counting blocks from the end
# of the model. With NUM_LAYERS=36 and BLOCK_SIZE=4: j=0 → 32,35; j=1 → 28,31.
# The result is not clamped, so a large enough j yields negative indices.
block_bounds() {
  local j=$1 first last
  first=$(( NUM_LAYERS - BLOCK_SIZE * (j + 1) ))
  last=$(( first + BLOCK_SIZE - 1 ))
  printf '%s,%s\n' "$first" "$last"
}
# Print the log tag for zero-based run $1: (run + 1) * BLOCK_SIZE.
log_tag_of() {
  local run=$1
  printf '%d\n' "$(( (run + 1) * BLOCK_SIZE ))"
}

# Append item to the comma-separated list acc and print the result; an empty
# acc yields just item (no leading comma).
_csv_append() {
  local acc=$1 item=$2
  printf '%s\n' "${acc:+${acc},}${item}"
}

# One deepspeed launch per block, starting with the last BLOCK_SIZE layers
# (block 0) and moving towards the front of the model. Run i passes block i
# as --init_idx/--training_idx and passes dimension overrides for every
# block 0..i whose scheduled dimension differs from the base dimension.
for ((i=0; i<TOTAL_STEPS; i++)); do
  br=$(block_bounds "$i"); b_start=${br%%,*}; b_end=${br##*,}
  # Stop once the schedule would run past the first layer.
  if (( b_start < 0 )); then break; fi
  init_idx=$(range_join "$b_start" "$b_end")
  training_idx="$init_idx"

  cum_mlp_idx=""; cum_mlp_dims=""
  cum_att_idx=""; cum_att_dims=""

  # Walk from the current block back to block 0 so the lists come out in
  # ascending layer order.
  # NOTE(review): repeat4 always emits four entries, so the dims list lines
  # up one-to-one with idx_j only while BLOCK_SIZE is 4.
  for ((j=i; j>=0; j--)); do
    brj=$(block_bounds "$j")
    idx_j=$(range_join "${brj%%,*}" "${brj##*,}")

    # MLP
    if (( j < ${#MLP_BLOCK_DIMS[@]} )); then
      dim_j=${MLP_BLOCK_DIMS[j]}
      if [[ -n "$dim_j" && "$dim_j" -ne "$BASE_MLP_DIM" ]]; then
        cum_mlp_idx=$(_csv_append "$cum_mlp_idx" "$idx_j")
        cum_mlp_dims=$(_csv_append "$cum_mlp_dims" "$(repeat4 "$dim_j")")
      fi
    fi

    # ATTN
    if (( j < ${#ATT_BLOCK_DIMS[@]} )); then
      h_j=${ATT_BLOCK_DIMS[j]}
      if [[ -n "$h_j" && "$h_j" -ne "$BASE_HEAD_DIM" ]]; then
        cum_att_idx=$(_csv_append "$cum_att_idx" "$idx_j")
        cum_att_dims=$(_csv_append "$cum_att_dims" "$(repeat4 "$h_j")")
      fi
    fi
  done

  # Only referenced by the disabled log redirect at the bottom of the loop.
  log_tag=$(log_tag_of "$i")

  args=(
    deepspeed --hostfile "$HOSTFILE" run_comp_kd.py
    --model_name_or_path "$MODEL"
    --dataset_name slimpajama
    --deepspeed "$DSCONF"
    --bf16 true
    --per_device_train_batch_size 8
    --per_device_eval_batch_size 1
    --gradient_accumulation_steps 1
    --num_train_epochs 1
    --do_train
    --output_dir "$OUT_DIR"
    --weight_decay 0.01
    --bf16_full_eval true
    # NOTE(review): Hugging Face TrainingArguments names this option
    # --logging_steps; the shorter spelling presumably relies on argparse
    # prefix matching - confirm against run_comp_kd.py.
    --logging_step 10
    --report_to none
    --overwrite_output_dir
    --alpha 0.0
    --beta 1.0
    --save_strategy epoch
    --streaming true
    --learning_rate 5e-5
    --max_grad_norm 1.0
    --block_size 1024
    --init_idx "$init_idx"
    --training_idx "$training_idx"
    --distill_loss kl
    --logging_first_step true
    --num_samples 600
  )

  # Override lists are passed only when at least one block deviates from the
  # base dimension.
  if [[ -n "$cum_mlp_idx" ]]; then
    args+=( --mlp_idx "$cum_mlp_idx" --mlp_output_dim "$cum_mlp_dims" )
  fi
  if [[ -n "$cum_att_idx" ]]; then
    args+=( --attn_idx "$cum_att_idx" --attn_output_dim "$cum_att_dims" )
  fi

  # From the second run on, the student is loaded from OUT_DIR (presumably
  # the checkpoint the previous run saved there).
  if (( i >= 1 )); then
    args+=( --student_model_name_or_path "$OUT_DIR" )
  fi

  echo "[Run $((i+1))/${TOTAL_STEPS}] init/training=${init_idx}"
  if [[ -n "$cum_mlp_idx" ]]; then
    echo "  MLP  idx=${cum_mlp_idx}  dims=${cum_mlp_dims}"
  fi
  if [[ -n "$cum_att_idx" ]]; then
    echo "  ATTN idx=${cum_att_idx}  dims=${cum_att_dims}"
  fi

  # Under `set -e` a failing run aborts the whole schedule.
  "${args[@]}" #> "${OUT_DIR}/log-${log_tag}.txt" 2>&1
done
