#! /bin/bash

# KD training with per-epoch dataset schedule for Dolly spectrum.
# Epochs 1-4 -> gpt2_spectrum/gpt2
# Epoch 5    -> gpt2_spectrum/s_1/gpt2
# Epoch 6    -> gpt2_spectrum/s_2/gpt2
# Epoch 7    -> gpt2_spectrum/s_3/gpt2
# Epoch 8    -> gpt2_spectrum/s_4/gpt2
# Epoch 9    -> gpt2_spectrum/s_5/gpt2
# Epoch 10   -> gpt2_spectrum/s_6/gpt2
# Epoch 11   -> gpt2_spectrum/s_7/gpt2
# Epoch 12   -> gpt2_spectrum/s_8/gpt2
# Temperature-KLD linearly ramps 1.0 -> 2.0 across 12 epochs. Loss type: tfkl.
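# Usage (all arguments are optional and positional):
#   bash kd_multistage.sh [base_path] [port] [gpus] [student_init_ckpt]
#   Defaults: base_path=/home/spectrumKD, port=2012, gpus=4,
#             student_init_ckpt=<base_path>/checkpoints/gpt2-base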

# Distributed launch settings for a single-node torchrun job.
# Positional arguments read here:
#   $2 - master port (default: 2012)
#   $3 - requested number of worker processes / GPUs (default: 4)
MASTER_ADDR=localhost
MASTER_PORT=${2-2012}
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=${3-4}

# GPUS_PER_NODE is compared numerically below and handed to torchrun, so
# reject anything that is not a positive integer before going further.
if ! [[ "${GPUS_PER_NODE}" =~ ^[1-9][0-9]*$ ]]; then
  echo "[error] <gpus> must be a positive integer, got '${GPUS_PER_NODE}'." >&2
  exit 1
fi

# Detect visible GPUs and auto-adjust nproc_per_node to avoid invalid device ordinal
if [ -n "${CUDA_VISIBLE_DEVICES}" ]; then
  # One entry per comma-separated device id.
  IFS=',' read -ra DEV_ARR <<< "${CUDA_VISIBLE_DEVICES}"
  NUM_VISIBLE=${#DEV_ARR[@]}
else
  # Count the lines printed by `nvidia-smi -L`.
  NUM_VISIBLE=$(nvidia-smi -L 2>/dev/null | wc -l | awk '{print $1}')
fi
# `wc -l` prints 0 (not an empty string) when nvidia-smi is missing or lists
# nothing, so a plain ${VAR:-1} default would never apply. Fall back to a
# single process whenever the count is not a positive integer; otherwise the
# clamp below would request zero workers.
if ! [[ "${NUM_VISIBLE}" =~ ^[1-9][0-9]*$ ]]; then
  NUM_VISIBLE=1
fi
if [ "${GPUS_PER_NODE}" -gt "${NUM_VISIBLE}" ]; then
  echo "[warn] Requested GPUs (${GPUS_PER_NODE}) exceeds visible GPUs (${NUM_VISIBLE}). Using ${NUM_VISIBLE}."
  GPUS_PER_NODE=${NUM_VISIBLE}
fi

# Kept as a whitespace-separated string: it is expanded unquoted (word-split)
# where the torchrun command line is assembled.
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
                  --nnodes $NNODES \
                  --node_rank $NODE_RANK \
                  --master_addr $MASTER_ADDR \
                  --master_port $MASTER_PORT"

# --- checkpoints -------------------------------------------------------------
# $1 (when given) is the project root; $4 (when given) is the student
# checkpoint to start from. An explicitly passed empty string is kept as-is.
if [ "$#" -ge 1 ]; then
  BASE_PATH=$1
else
  BASE_PATH="/home/spectrumKD"
fi
if [ "$#" -ge 4 ]; then
  STUDENT_INIT=$4
else
  STUDENT_INIT="${BASE_PATH}/checkpoints/gpt2-base"
fi

# ckpt-name is "<parent dir name>/<dir name>" of the student init path.
CKPT_BASE=$(basename "${STUDENT_INIT}")
CKPT_PARENT=$(dirname "${STUDENT_INIT}")
CKPT_PARENT=$(basename "${CKPT_PARENT}")
CKPT_NAME="${CKPT_PARENT}/${CKPT_BASE}"

# Teacher (finetuned gpt2-xl), expected as a local checkpoint under BASE_PATH.
TEACHER_CKPT_NAME="gpt-xl-teacher"
TEACHER_CKPT="${BASE_PATH}/checkpoints/${TEACHER_CKPT_NAME}"

# --- hyperparameters (tuned for 2x A800-80G; will also work on 1 GPU) --------
LR=0.0003
BATCH_SIZE=4
EVAL_BATCH_SIZE=16
GRAD_ACC=1
MAX_LENGTH=512
NUM_WORKERS=8

# --- schedule / output -------------------------------------------------------
TOTAL_EPOCHS=12
SEED=10
SAVE_PATH="${BASE_PATH}/results/gpt2/train/gpt2_multistage/${CKPT_NAME}"

# --- environment for the training processes ----------------------------------
export NCCL_DEBUG="" WANDB_DISABLED=True TF_CPP_MIN_LOG_LEVEL=3
export PYTHONPATH="${BASE_PATH}"

# Find the checkpoint directory with the highest numeric step under a save dir.
# Supports two layouts: "<run_dir>/<step>" (e.g., e1-.../14290) and direct "<step>".
# Arguments:
#   $1 - save directory to scan (a trailing slash is tolerated)
# Outputs:
#   Path relative to $1 (either "<run>/<step>" or "<step>") on stdout, or
#   nothing when $1 is empty, is not a directory, or holds no all-digit
#   directory at depth 1 or 2.
# Returns:
#   0 in all of the cases above.
latest_step_dir() {
  local root="$1"
  [[ -n "$root" && -d "$root" ]] || return 0

  # find runs from inside the directory, so every hit is already relative
  # ("./<run>/<step>"). Stripping the prefix afterwards with awk sub() would
  # interpret the directory path as a regular expression and fail to strip it
  # for paths containing metacharacters such as '+'.
  (
    CDPATH= cd -- "$root" 2>/dev/null || exit 0
    find . -mindepth 1 -maxdepth 2 -type d -regex '.*/[0-9]+' 2>/dev/null
  ) \
    | awk -F'/' '{ sub(/^\.\//, ""); print $0 "\t" $NF }' \
    | sort -t $'\t' -k2,2n \
    | tail -n 1 \
    | cut -f1
  # Each line above is "<relative path>\t<step>". The explicit tab separator
  # keeps the numeric key correct even when the path itself contains blanks.
}

# Create the output directory before any run starts. The path is quoted so it
# is passed to mkdir as a single argument, and a failure stops the script
# here instead of continuing without the directory.
if ! mkdir -p -- "${SAVE_PATH}"; then
  echo "[error] cannot create save directory: ${SAVE_PATH}" >&2
  exit 1
fi

# Map an epoch number to the dataset directory used for that epoch.
# Globals:
#   BASE_PATH (read)
# Arguments:
#   $1 - epoch number
# Outputs:
#   Epochs up to 4 print ".../gpt2_spectrum/gpt2/"; any later epoch E prints
#   ".../gpt2_spectrum/s_<E-4>/gpt2/".
dataset_dir_for_epoch() {
  local epoch=$1
  local stage="gpt2"
  if ! [ "$epoch" -le 4 ]; then
    stage="s_$((epoch-4))/gpt2"
  fi
  printf '%s\n' "${BASE_PATH}/processed_data/dolly/full/gpt2_spectrum/${stage}/"
}

# Options that stay the same for every per-epoch run. Each array entry is one
# "--flag [value]" fragment; BASE_OPTS is the fragments joined into a single
# string, each preceded by one space. It stays a plain string because it is
# expanded unquoted (word-split) when the command line is assembled.
_base_opt_fragments=(
  # model / teacher
  "--base-path ${BASE_PATH}"
  "--teacher-model-path ${TEACHER_CKPT}"
  "--ckpt-name ${CKPT_NAME}"
  "--teacher-ckpt-name ${TEACHER_CKPT_NAME}"
  "--teacher-model-fp16"
  "--n-gpu ${GPUS_PER_NODE}"
  # data
  "--num-workers ${NUM_WORKERS}"
  "--dev-num 1000"
  # optimisation
  "--lr ${LR}"
  "--batch-size ${BATCH_SIZE}"
  "--eval-batch-size ${EVAL_BATCH_SIZE}"
  "--gradient-accumulation-steps ${GRAD_ACC}"
  "--warmup-iters 1000"
  "--lr-decay-style cosine"
  "--weight-decay 1e-2"
  "--clip-grad 1.0"
  "--epochs 1"
  "--kd-ratio 0.7"
  # sequence lengths
  "--max-length ${MAX_LENGTH}"
  "--max-prompt-length 256"
  # runtime: train / validate / generate, logging and saving
  "--do-train"
  "--do-valid"
  "--eval-gen"
  "--save-interval -1"
  "--eval-interval -1"
  "--log-interval 64"
  "--mid-log-num -1"
  "--save ${SAVE_PATH}"
  "--seed ${SEED}"
  # deepspeed
  "--deepspeed"
  "--deepspeed_config ${BASE_PATH}/configs/deepspeed/ds_config.json"
  "--loss-scale 32768"
  # loss type
  "--type tfkl"
  # generation
  "--do-sample"
  "--top-k 0"
  "--top-p 1.0"
  "--temperature 1.0"
)
printf -v BASE_OPTS ' %s' "${_base_opt_fragments[@]}"
unset _base_opt_fragments

# Run the schedule: one torchrun invocation per epoch. Each epoch picks its
# dataset via dataset_dir_for_epoch, its KLD temperature from the linear ramp,
# and starts from the newest checkpoint found under SAVE_PATH (or from
# STUDENT_INIT when none exists yet). The script stops at the first epoch
# whose training command fails, so later stages never run on a stale model.
#
# NOTE(review): the loop always starts at epoch 1, and "newest" means highest
# step number. If SAVE_PATH already holds checkpoints from an earlier
# invocation, epoch 1 resumes from them. Confirm this is the intended restart
# behaviour and that step numbers keep increasing across epochs.
for ((E = 1; E <= TOTAL_EPOCHS; E++)); do
  # Linear schedule from 1.0 -> 2.0 across TOTAL_EPOCHS
  # T(E) = 1.0 + (E-1) * (2.0-1.0)/(TOTAL_EPOCHS-1)
  if [ "${TOTAL_EPOCHS}" -gt 1 ]; then
    # LC_ALL=C keeps '.' as the decimal separator regardless of locale.
    TEMPERATURE_KLD=$(LC_ALL=C awk -v e="${E}" -v n="${TOTAL_EPOCHS}" \
      'BEGIN { printf "%.6f\n", 1.0 + (e - 1) * (2.0 - 1.0) / (n - 1) }')
  else
    TEMPERATURE_KLD=1.0
  fi
  # An empty value would leave --temperature-kld without an argument.
  if [ -z "${TEMPERATURE_KLD}" ]; then
    echo "[error] could not compute temperature-kld for epoch ${E}." >&2
    exit 1
  fi

  DATA_DIR=$(dataset_dir_for_epoch "${E}")

  LAST=$(latest_step_dir "${SAVE_PATH}")
  if [ -n "${LAST}" ]; then
    MODEL_PATH="${SAVE_PATH}/${LAST}"
  else
    MODEL_PATH="${STUDENT_INIT}"
  fi

  # DISTRIBUTED_ARGS and BASE_OPTS are whitespace-separated option strings and
  # are word-split on purpose; the per-epoch values are passed as single,
  # quoted arguments.
  # shellcheck disable=SC2206
  RUN_CMD=(
    torchrun ${DISTRIBUTED_ARGS}
    "${BASE_PATH}/finetune.py"
    ${BASE_OPTS}
    --model-path "${MODEL_PATH}"
    --data-dir "${DATA_DIR}"
    --temperature-kld "${TEMPERATURE_KLD}"
  )

  # printf is used for the banner because bash's echo prints "\n" literally
  # unless -e is given.
  printf '\n===== Epoch %s/%s =====\n' "${E}" "${TOTAL_EPOCHS}"
  echo "[info] model-path: ${MODEL_PATH}"
  echo "[info] data-dir: ${DATA_DIR}"
  echo "[info] temperature-kld: ${TEMPERATURE_KLD}"
  echo "${RUN_CMD[*]}"

  if ! "${RUN_CMD[@]}"; then
    echo "[error] epoch ${E}/${TOTAL_EPOCHS} failed; aborting the schedule." >&2
    exit 1
  fi
done

