#! /bin/bash

# Rendezvous settings for a single-node torchrun launch.
MASTER_ADDR=localhost
# $2: rendezvous port. ':-' (rather than '-') makes an explicitly empty
# argument fall back to the default too, so a caller can pass "" to skip
# a position while still supplying later ones.
MASTER_PORT=${2:-2112}
NNODES=1
NODE_RANK=0
# $3: number of worker processes (GPUs) started on this node.
GPUS_PER_NODE=${3:-2}

# Launcher flags kept as one flat, whitespace-separated string; it is
# word-split when the torchrun command line is assembled further down.
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE \
                  --nnodes $NNODES \
                  --node_rank $NODE_RANK \
                  --master_addr $MASTER_ADDR \
                  --master_port $MASTER_PORT"

# model
# $1: project root; checkpoints, data, configs and tools/ are resolved
# relative to it. Empty or missing falls back to the default location.
BASE_PATH=${1:-"/hy-tmp/dc"}

## Optional args:
## $4 = CKPT_NAME (used only for naming outputs)
## $5 = CKPT_DIR (absolute path to model directory). If not provided, falls back to BASE_PATH/checkpoints/CKPT_NAME/
## Examples:
##   bash generate_data_dolly.sh /hy-tmp/dc 2112 1 gpt2-xl /hy-tmp/dc/results/gpt2/train/sft/e10-bs4-lr5e-05-G2-N1-NN1/702
##   bash generate_data_dolly.sh /hy-tmp/dc 2112 1 gpt2-xl

#CKPT_NAME=${4-"gpt2_kd/t3/768"} #24.88

#CKPT_NAME=${4-"gpt2_kd/t3/576"} #

#CKPT_NAME=${4-"hr0.01_no_rdistill_0.1B_sft1.5B_off_epochs8_ratio0.7-1_tmp1-4/t3/1416"} #25.2
#CKPT_NAME=${4-"hr0.01_no_jsd_0.1B_sft1.5B_on_epochs8_ratio0.7-1_tmp1-4/t3/944"}
#CKPT_NAME=${4-"hr0.01_no_jsd_0.1B_sft1.5B_off_epochs8_ratio0.7-1_tmp1-2/t3/890"} #27.8 #25.99
#CKPT_NAME=${4-"hr0.01_no_jsd_0.1B_sft1.5B_off_epochs8_ratio0.7-1_tmp1-4/t3/890"} #26.2 #23.57
#CKPT_NAME=${4-"jsd_0.1B_sft1.5B_off_epochs20_tmp1/3560"} #26.6 #25.19
# Label used for naming outputs; $4 overrides it. An empty $4 counts as
# "not given" so the default is still applied.
CKPT_NAME=${4:-"gpt2-base"}

# Succeeds when directory $1 directly contains model weights in one of the
# recognised layouts: single-file or sharded safetensors, or PyTorch .bin
# (with or without an index file).
_ckpt_has_weights() {
  local dir=$1 name shard
  for name in model.safetensors model.safetensors.index.json \
              pytorch_model.bin pytorch_model.bin.index.json; do
    [[ -f "${dir}/${name}" ]] && return 0
  done
  # Sharded safetensors; an unmatched glob stays literal and fails -e.
  for shard in "${dir}"/model-*.safetensors; do
    [[ -e "${shard}" ]] && return 0
  done
  return 1
}

# Prints the name of the numerically largest all-digit subdirectory of $1,
# or nothing when there is none. Only directories are considered, so a
# stray file with a numeric name cannot hide real step directories.
_ckpt_latest_step_dir() {
  local run_dir=$1 entry name best=""
  for entry in "${run_dir}"/*/; do
    [[ -d "${entry}" ]] || continue
    name=${entry%/}
    name=${name##*/}
    [[ "${name}" =~ ^[0-9]+$ ]] || continue
    # 10# forces base 10 so names with leading zeros are not read as octal.
    if [[ -z "${best}" ]] || (( 10#${name} > 10#${best} )); then
      best=${name}
    fi
  done
  printf '%s' "${best}"
}

# Prefer an explicit checkpoint directory ($5) when given; otherwise look
# the checkpoint up by name under BASE_PATH/checkpoints.
if [[ -n "${5:-}" ]]; then
  CKPT_IN=$5
  if _ckpt_has_weights "${CKPT_IN}"; then
    # $5 already points at a leaf checkpoint directory.
    CKPT=${CKPT_IN}
  else
    # $5 looks like a run directory: descend into its highest-numbered
    # step directory, or keep $5 unchanged when it has none.
    LAST_SUBDIR=$(_ckpt_latest_step_dir "${CKPT_IN}")
    if [[ -n "${LAST_SUBDIR}" ]]; then
      CKPT="${CKPT_IN}/${LAST_SUBDIR}"
    else
      CKPT=${CKPT_IN}
    fi
  fi
  # Derive a distinctive save name of the form parent/basename. A path
  # without any slash keeps the CKPT_NAME chosen above.
  if [[ "${CKPT}" == */* ]]; then
    CKPT_PARENT=$(basename "$(dirname "${CKPT}")")
    CKPT_BASE=$(basename "${CKPT}")
    CKPT_NAME="${CKPT_PARENT}/${CKPT_BASE}"
  fi
else
  CKPT="${BASE_PATH}/checkpoints/${CKPT_NAME}/"
fi
# ---- data ----
# Common root of the processed Dolly data, shared by the input and output
# locations below.
DOLLY_ROOT="${BASE_PATH}/processed_data/dolly/full"
DATA_NAMES='dolly'
# Use processed Dolly JSONL so PromptDataset --json-data finds
# train.jsonl/valid.jsonl.
DATA_DIR="${DOLLY_ROOT}/gpt2"

# ---- hyperparameters ----
EVAL_BATCH_SIZE=16

# ---- runtime ----
# Directory handed to --save (trailing slash kept as-is).
SAVE_PATH="${DOLLY_ROOT}/generate_data/"
TYPE='eval_main'


# Arguments for tools/generate_data.py. They are collected in an array
# (not a flat string) so that values containing whitespace or glob
# characters reach the program as single, unmodified arguments.
OPTS=()
# model
OPTS+=(--base-path "${BASE_PATH}")
OPTS+=(--model-path "${CKPT}")
OPTS+=(--ckpt-name "${CKPT_NAME}")
OPTS+=(--n-gpu "${GPUS_PER_NODE}")
OPTS+=(--model-type gpt2)
# data
OPTS+=(--data-dir "${DATA_DIR}")
OPTS+=(--data-names "${DATA_NAMES}")
OPTS+=(--num-workers 0)
OPTS+=(--dev-num -1)
OPTS+=(--data-process-workers -1)
OPTS+=(--json-data)
# hp
OPTS+=(--eval-batch-size "${EVAL_BATCH_SIZE}")
OPTS+=(--max-length 512)
OPTS+=(--max-prompt-length 256)
# runtime
OPTS+=(--do-eval)
OPTS+=(--save "${SAVE_PATH}")
OPTS+=(--seed 10)
# deepspeed
OPTS+=(--deepspeed)
OPTS+=(--deepspeed_config "${BASE_PATH}/configs/deepspeed/ds_config.json")
OPTS+=(--type "${TYPE}")
# gen
# NOTE(review): sampling with temperature 0.01 is presumably meant to be
# close to greedy decoding -- confirm against tools/generate_data.py.
OPTS+=(--do-sample)
OPTS+=(--top-k 0)
OPTS+=(--top-p 1.0)
OPTS+=(--temperature 0.01)


export NCCL_DEBUG=""
export TOKENIZERS_PARALLELISM=false
export PYTHONIOENCODING=utf-8
export PYTHONPATH="${BASE_PATH}"

# DISTRIBUTED_ARGS is a flat, whitespace-separated flag string; split it
# into words explicitly (without glob expansion) instead of relying on an
# unquoted expansion.
read -r -a DIST_ARGS <<< "${DISTRIBUTED_ARGS}"

# NOTE(review): "$@" forwards every positional argument, including the
# ones already consumed above as $1..$5; presumably the Python side
# tolerates unknown arguments -- confirm before changing this.
CMD=(torchrun "${DIST_ARGS[@]}" "${BASE_PATH}/tools/generate_data.py" "${OPTS[@]}" "$@")

# Log the command in a shell-quoted, copy-pasteable form.
printf -v CMD_DISPLAY '%q ' "${CMD[@]}"
echo "${CMD_DISPLAY% }"
echo "PYTHONPATH=${PYTHONPATH}"
mkdir -p -- "${SAVE_PATH}"
"${CMD[@]}"
