#!/usr/bin/env bash
#
# Launch knowledge-distillation (KD) training of a GPT-2 student against a
# GPT-2 SFT teacher using torchrun + DeepSpeed (${BASE_PATH}/finetune.py).
#
# Usage:
#   <script> [BASE_PATH] [MASTER_PORT] [GPUS_PER_NODE] [STUDENT_CKPT] [LOAD_PATH] [EXTRA_ARGS...]
#
#   $1   BASE_PATH      repo root (default: /home/spectrumKD)
#   $2   MASTER_PORT    torchrun master port (default: 2012)
#   $3   GPUS_PER_NODE  processes per node (default: 2)
#   $4   STUDENT_CKPT   optional student checkpoint dir (HF dir) used as --model-path
#   $5   LOAD_PATH      optional path passed as --load (resume)
#   $6.. EXTRA_ARGS     forwarded verbatim to finetune.py
#
# Pass "" for $4/$5 to skip them while still supplying later arguments.
set -euo pipefail

# distributed launch settings (single node)
MASTER_ADDR=localhost
MASTER_PORT=${2:-2012}
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=${3:-2}

# Arrays keep every value a single argv word, even if it contains spaces.
DISTRIBUTED_ARGS=(
  --nproc_per_node "${GPUS_PER_NODE}"
  --nnodes "${NNODES}"
  --node_rank "${NODE_RANK}"
  --master_addr "${MASTER_ADDR}"
  --master_port "${MASTER_PORT}"
)

# model
BASE_PATH=${1:-"/home/spectrumKD"}
# Optional 4th arg: student checkpoint path to use as model-path (HF dir)
STUDENT_CKPT_ARG=${4:-}
# Optional 5th arg: load path (DeepSpeed/extra) if needed
LOAD_PATH=${5:-}
CKPT_NAME="gpt2-base"
CKPT="${BASE_PATH}/checkpoints/${CKPT_NAME}/"
# If a student checkpoint path is provided, use it as the model path and name
# the run "<parent dir>/<basename>" of that checkpoint.
if [[ -n "${STUDENT_CKPT_ARG}" ]]; then
  CKPT="${STUDENT_CKPT_ARG}"
  PARENT=$(basename "$(dirname "${CKPT}")")
  BASE=$(basename "${CKPT}")
  CKPT_NAME="${PARENT}/${BASE}"
fi
# CKPT="gpt2" # download automatically
TEACHER_CKPT_NAME="e10-bs4-lr5e-05-G2-N1-NN1/14290"
# Derived from TEACHER_CKPT_NAME so the name and the path cannot drift apart.
TEACHER_CKPT="${BASE_PATH}/results/gpt2/train/sft/${TEACHER_CKPT_NAME}"
# data (use curated indices to match the checkpoint's training data)
DATA_DIR="${BASE_PATH}/processed_data/dolly/full/gpt2/"
# hp
# Per-GPU micro-batch defaults (tuned for A800-80G), no gradient accumulation.
# NOTE(review): GPUS_PER_NODE defaults to 2 — confirm the intended GPU count.
BATCH_SIZE=4
LR=0.0003
GRAD_ACC=1
EVAL_BATCH_SIZE=16
NUM_WORKERS=8
# length
MAX_LENGTH=512
# runtime
SAVE_PATH="${BASE_PATH}/results/gpt2/train/kd"
# seed
SEED=10

OPTS=(
  # model
  --base-path "${BASE_PATH}"
  --model-path "${CKPT}"
  --teacher-model-path "${TEACHER_CKPT}"
  --ckpt-name "${CKPT_NAME}"
  --teacher-ckpt-name "${TEACHER_CKPT_NAME}"
  --teacher-model-fp16
  --n-gpu "${GPUS_PER_NODE}"
  # --gradient-checkpointing
  # data
  --data-dir "${DATA_DIR}"
  --num-workers "${NUM_WORKERS}"
  --dev-num 1000
  # hp
  --lr "${LR}"
  --batch-size "${BATCH_SIZE}"
  --eval-batch-size "${EVAL_BATCH_SIZE}"
  --gradient-accumulation-steps "${GRAD_ACC}"
  --warmup-iters 1000
  --lr-decay-style cosine
  --weight-decay 1e-2
  --clip-grad 1.0
  --epochs 10
  # Weight of the KD loss relative to CE: 0.7 emphasizes KD.
  # NOTE(review): 0.5 presumably weights CE and KD equally — confirm in finetune.py.
  --kd-ratio 0.7
  # length
  --max-length "${MAX_LENGTH}"
  --max-prompt-length 256
  # runtime
  --do-train
  --do-valid
  --eval-gen
  --save-interval -1
  --eval-interval -1
  --log-interval 64
  --mid-log-num -1
  --save "${SAVE_PATH}"
  # seed
  --seed "${SEED}"
  # deepspeed
  --deepspeed
  --deepspeed_config "${BASE_PATH}/configs/deepspeed/ds_config.json"
  # Keep dynamic loss scaling but provide a conservative starting loss-scale via CLI
  --loss-scale 32768
)
# If a load path is provided, resume from it
if [[ -n "${LOAD_PATH}" ]]; then
  OPTS+=(--load "${LOAD_PATH}")
fi
OPTS+=(
  # distillation objective; alternatives: rkl/jsd/sfkl/srkl/tvd
  --type kd
  # gen
  --do-sample
  --top-k 0
  --top-p 1.0
  --temperature 1.0
)

export NCCL_DEBUG=""
export WANDB_DISABLED=True
export TF_CPP_MIN_LOG_LEVEL=3
export PYTHONPATH="${BASE_PATH}"

# $1-$5 are this script's own settings; forward only the arguments after them
# so they do not leak into finetune.py's argv.
CMD=(torchrun "${DISTRIBUTED_ARGS[@]}" "${BASE_PATH}/finetune.py" "${OPTS[@]}" "${@:6}")

# Log the exact command, shell-quoted so it can be copy-pasted.
printf '%q ' "${CMD[@]}"
printf '\n'
echo "PYTHONPATH=${PYTHONPATH}"
mkdir -p -- "${SAVE_PATH}"
"${CMD[@]}"
