#! /bin/bash

# Launch evaluate.py under torchrun for a GPT-2 checkpoint on the "dolly" data.
#
# Usage: <this-script> [BASE_PATH] [MASTER_PORT] [GPUS_PER_NODE] [CKPT] [extra...]
#   BASE_PATH      project root                       (default: /home/spectrumKD)
#   MASTER_PORT    torchrun rendezvous port           (default: 2113)
#   GPUS_PER_NODE  processes to spawn on this node    (default: 1)
#   CKPT           checkpoint name, or a path if it contains a slash
#                                                     (default: gpt2-xl)
#
# NOTE: ALL positional arguments (including the four above) are forwarded
# verbatim to evaluate.py after the generated options, so extra flags can be
# appended on the command line.
# NOTE(review): forwarding the first four positionals relies on evaluate.py
# tolerating unknown arguments -- confirm against its argument parser.

# Stop on the first failing command, on unset variables, and on pipeline errors
# (e.g. do not start an evaluation run if the results directory cannot be made).
set -euo pipefail

# --- distributed launch settings (single node) ---
MASTER_ADDR=localhost
MASTER_PORT=${2-2113}
NNODES=1
NODE_RANK=0
GPUS_PER_NODE=${3-1}

# Kept as an array so each flag/value stays a separate word when executed.
DISTRIBUTED_ARGS=(
  --nproc_per_node "${GPUS_PER_NODE}"
  --nnodes "${NNODES}"
  --node_rank "${NODE_RANK}"
  --master_addr "${MASTER_ADDR}"
  --master_port "${MASTER_PORT}"
)

# --- model ---
BASE_PATH=${1-"/home/spectrumKD"}
# Use GPT-2 XL checkpoint directory or an absolute fine-tuned path
CKPT_ARG=${4-"gpt2-xl"}

# If CKPT_ARG looks like a path (contains a slash), treat it as the full
# checkpoint path. Otherwise, assume it is a checkpoint name under
# ${BASE_PATH}/checkpoints/.
if [[ "${CKPT_ARG}" == */* ]]; then
  CKPT="${CKPT_ARG}"
  # Name results after parent-dir/basename, e.g. e10-bs4-lr5e-05-G2-N1-NN1/14290
  CKPT_PARENT=$(basename "$(dirname "${CKPT_ARG}")")
  CKPT_BASE=$(basename "${CKPT_ARG}")
  CKPT_NAME="${CKPT_PARENT}/${CKPT_BASE}"
else
  CKPT_NAME="${CKPT_ARG}"
  CKPT="${BASE_PATH}/checkpoints/${CKPT_NAME}"
fi

# --- data ---
DATA_NAMES="dolly"
# Use processed JSONL so PromptDataset --json-data finds valid.jsonl
DATA_DIR="${BASE_PATH}/processed_data/dolly/full/gpt2"
# --- hyper-parameters ---
EVAL_BATCH_SIZE=16
# --- runtime ---
SAVE_PATH="${BASE_PATH}/results/gpt2/eval_main/"
TYPE="eval_main"

# Options for evaluate.py, one array element per word so that paths containing
# whitespace or glob characters reach the program intact.
OPTS=(
  # model
  --base-path "${BASE_PATH}"
  --model-path "${CKPT}"
  --ckpt-name "${CKPT_NAME}"
  --n-gpu "${GPUS_PER_NODE}"
  --model-type gpt2
  # data
  --data-dir "${DATA_DIR}"
  --data-names "${DATA_NAMES}"
  --num-workers 4
  --dev-num -1
  --data-process-workers -1
  --json-data
  # hp
  --eval-batch-size "${EVAL_BATCH_SIZE}"
  --max-length 512
  --max-prompt-length 256
  # runtime
  --do-eval
  --save "${SAVE_PATH}"
  --seed 10
  # deepspeed
  --deepspeed
  --deepspeed_config "${BASE_PATH}/configs/deepspeed/ds_config.json"
  --type "${TYPE}"
  # gen
  --do-sample
  --top-k 0
  --top-p 1.0
  --temperature 1.0
)

export NCCL_DEBUG=""
export TOKENIZERS_PARALLELISM=false
export PYTHONIOENCODING=utf-8
# Replaces (does not extend) any inherited PYTHONPATH with the project root.
export PYTHONPATH="${BASE_PATH}"

# Full command as an array; "$@" keeps every forwarded argument as its own word.
CMD=(torchrun "${DISTRIBUTED_ARGS[@]}" "${BASE_PATH}/evaluate.py" "${OPTS[@]}" "$@")

# Log the command (space-joined) and the effective PYTHONPATH before running.
echo "${CMD[*]}"
echo "PYTHONPATH=${PYTHONPATH}"
mkdir -p "${SAVE_PATH}"
# Last command: its exit status becomes the script's exit status.
"${CMD[@]}"
