#!/usr/bin/env bash
# Evaluate one or more model checkpoints on a math test set:
# generate -> (optional) self-verify -> judge -> compute metrics.
set -euo pipefail

# Resolve this script's directory. Quote $0 so paths containing spaces work,
# and chain with && so a failed cd aborts instead of silently yielding $PWD
# (command substitutions do not inherit `set -e` in bash).
SCRIPT_DIR=$(cd -- "$(dirname -- "$0")" && pwd)
WORK_DIR=$SCRIPT_DIR/..
# GEN_DIR itself is not created here; per-run output dirs are made with mkdir -p.
GEN_DIR=$SCRIPT_DIR/../outputs
TASK_FLAG=default
# Holds lock-file paths of runs that have started but not yet finished.
SELF_RECORD_FILE=$SCRIPT_DIR/${TASK_FLAG}_mathtest.log

# Disable tokenizer parallelism and force Hugging Face offline mode
# (datasets/models are expected to be available locally).
export TOKENIZERS_PARALLELISM=false
# NOTE(review): presumably consumed by math_eval.py / its dependencies — confirm.
export LM_HARNESS_CACHE_PATH=$WORK_DIR/cache
export HF_DATASETS_OFFLINE=1
export TRANSFORMERS_OFFLINE=1

# Positional arguments. The first nine are required (pass "" to fall back to
# the ${N:-...} default where one exists); the last three (marked * in the
# usage line) are optional.
if (( $# < 9 )); then
    echo "Usage: $0 <MODEL_PATH> <CHECKPOINT> <CHAT_TEMPLATE> <TEST_SET> <GREEDY> <GEN> <TEMP> <TOP_P> <SYS_PROMPT> *<ADD_POSTFIX> *<SELF_VERIFY> *<JUDGE>" >&2
    exit 1
fi

MODEL_PATH=$1             # model dir, or parent dir of checkpoint-* subdirs
CHECKPOINT=$2             # 'all', 'none', or a checkpoint-* directory name
CHAT_TEMPLATE=$3          # forwarded as --chat_template_name
TEST_SET=$4               # data file; its base name labels the output dir
GREEDY=${5:-true}         # "true" -> greedy accuracy; otherwise sampling metrics
GEN=${6:-1}               # generations per question
TEMP=${7:-0}              # sampling temperature
TOP_P=${8:-1}             # top-p for sampling
SYSTEM_PROMPT=${9:-""}    # forwarded as --system_prompt
ADD_POSTFIX=${10:-""}     # forwarded as --add_prompt
SELF_VERIFY=${11:-true}   # "true" enables the self_verify step
JUDGE=${12:-"rule_based"} # 'rule_based' or 'llm_as_a_judge'

########################################
#          Sanity check
########################################
# Echo the effective configuration so every log starts with it.
for VAR_NAME in MODEL_PATH CHECKPOINT CHAT_TEMPLATE TEST_SET GREEDY GEN TEMP \
    TOP_P SYSTEM_PROMPT ADD_POSTFIX SELF_VERIFY JUDGE; do
    echo "$VAR_NAME: ${!VAR_NAME}"
done

MODEL_NAME=$(basename -- "$MODEL_PATH")
echo "MODEL_NAME: $MODEL_NAME"
# Take the file name first, then drop its extension, so a dot in a parent
# directory (e.g. data/v1.2/test) is not mistaken for the extension.
DATASET_NAME=$(basename -- "$TEST_SET")
DATASET_NAME=${DATASET_NAME%.*}
echo "DATASET_NAME: $DATASET_NAME"

# Field names read from each test-set record.
EXPECTED_QUES_KEY="problem"
EXPECTED_ANS_KEY="answer"

echo "EXPECTED_QUES_KEY: $EXPECTED_QUES_KEY"
echo "EXPECTED_ANS_KEY: $EXPECTED_ANS_KEY"

# Only two judging modes are supported; fail fast on anything else.
case "$JUDGE" in
    llm_as_a_judge | rule_based) ;;
    *)
        echo "Invalid JUDGE value: $JUDGE. Must be either 'llm_as_a_judge' or 'rule_based'." >&2
        exit 1
        ;;
esac

# Resolve which checkpoint(s) to evaluate:
#   CHECKPOINT == "all" -> every checkpoint-* subdirectory of MODEL_PATH
#   otherwise           -> whitespace-separated checkpoint-* names, or 'none'
#                          to evaluate MODEL_PATH itself
ALL_CHECKPOINTS=()
if [[ $CHECKPOINT == "all" ]]; then
    # Glob instead of parsing `ls`: only checkpoint-* directories are picked
    # up, so other files in a training output dir (configs, tokenizer files,
    # logs) no longer trip the "Invalid checkpoint" exit below.
    for CKPT_PATH in "$MODEL_PATH"/checkpoint-*/; do
        [[ -d $CKPT_PATH ]] || continue   # unmatched glob stays literal
        CKPT_PATH=${CKPT_PATH%/}
        ALL_CHECKPOINTS+=("${CKPT_PATH##*/}")
    done
else
    read -r -a ALL_CHECKPOINTS <<< "$CHECKPOINT"
fi

if (( ${#ALL_CHECKPOINTS[@]} == 0 )); then
    echo "No checkpoint to evaluate (CHECKPOINT='$CHECKPOINT', MODEL_PATH='$MODEL_PATH')" >&2
    exit 1
fi

# loop over all checkpoints
for CHECKPOINT_NAME in "${ALL_CHECKPOINTS[@]}"; do
    if [[ $CHECKPOINT_NAME == "checkpoint-"* ]]; then
        echo "Processing checkpoint: $CHECKPOINT_NAME"
        EXP_DIR=$GEN_DIR/$MODEL_NAME/$CHECKPOINT_NAME
        EVAL_MODEL_DIR=$MODEL_PATH/$CHECKPOINT_NAME
    elif [[ $CHECKPOINT_NAME == "none" ]]; then
        echo "Processing without checkpoint"
        EXP_DIR=$GEN_DIR/$MODEL_NAME
        EVAL_MODEL_DIR=$MODEL_PATH
    else
        echo "Invalid checkpoint: $CHECKPOINT_NAME" >&2
        exit 1
    fi

    # Each run writes into its own timestamped directory.
    MATH_EXP_DIR=$EXP_DIR/${DATASET_NAME}/${JUDGE}/gen${GEN}_temp${TEMP}_topp${TOP_P}/$(date +%Y%m%d_%H%M%S)
    LOCK_FILE=$MATH_EXP_DIR/lock.log

    # Already done or handled by another invocation: skip instead of
    # re-running into the same directory.
    if [[ -e $LOCK_FILE ]]; then
        echo "Already exists: $LOCK_FILE"
        continue
    fi

    # Record the run as in progress. The entry is removed only after every
    # step below succeeds, so with `set -e` a failing step leaves it behind.
    echo "$LOCK_FILE" >> "$SELF_RECORD_FILE"

    # run
    mkdir -p -- "$MATH_EXP_DIR"
    date '+%Y-%m-%d %H:%M:%S' >> "$LOCK_FILE"

    # 1. generate
    python3 "$WORK_DIR/math_eval.py" generate \
        --base_model "$EVAL_MODEL_DIR" \
        --chat_template_name "$CHAT_TEMPLATE" \
        --system_prompt "$SYSTEM_PROMPT" \
        --output_dir "$MATH_EXP_DIR" \
        --bf16 True \
        --data_file "$TEST_SET" \
        --question_key "$EXPECTED_QUES_KEY" \
        --add_prompt "$ADD_POSTFIX" \
        --gen_per_question "$GEN" \
        --greedy "$GREEDY" \
        --temperature "$TEMP" \
        --top_p "$TOP_P"

    # 1.1 (Optional) Self-verify, operating on the generation output.
    if [[ $SELF_VERIFY == "true" ]]; then
        python3 "$WORK_DIR/math_eval.py" self_verify \
            --base_model "$EVAL_MODEL_DIR" \
            --output_file "$MATH_EXP_DIR/generation.jsonl" \
            --bf16 True \
            --question_key "$EXPECTED_QUES_KEY" \
            --gen_per_question "$GEN"
    fi

    # 2. judge
    python3 "$WORK_DIR/math_eval.py" judge \
        --config_file "$MATH_EXP_DIR/config.json" \
        --llm_judge "$JUDGE" \
        --expected_ans_key "$EXPECTED_ANS_KEY"

    # 3. compute metrics: greedy accuracy, or sampling metrics. The
    #    --c_acc/--wsc flags are only requested when self-verify ran.
    METRIC_ARGS=(--config_file "$MATH_EXP_DIR/config.json")
    if [[ $GREEDY == "true" ]]; then
        echo "Computing Greedy Acc ..."
        METRIC_ARGS+=(--greedy)
    else
        echo "Computing Sampling Metrics ..."
        METRIC_ARGS+=(--pass_at_k --pass_ratio)
        if [[ $SELF_VERIFY == "true" ]]; then
            METRIC_ARGS+=(--c_acc --sc --wsc)
        else
            METRIC_ARGS+=(--sc)
        fi
    fi
    python3 "$WORK_DIR/math_eval.py" compute_metrics "${METRIC_ARGS[@]}"

    # Run finished: drop its entry from the record file. Match the path as a
    # fixed, whole line (grep -F -x) rather than splicing it into a sed regex,
    # where '.', '[' or '|' in the path would be interpreted as syntax.
    # grep exits 1 when it selects no lines (the last entry was removed);
    # only a status >= 2 is a real error.
    RECORD_TMP=$(mktemp "${SELF_RECORD_FILE}.XXXXXX")
    grep -v -x -F -- "$LOCK_FILE" "$SELF_RECORD_FILE" > "$RECORD_TMP" || [[ $? -eq 1 ]]
    # Copy back (not mv) so the record file keeps its inode and permissions.
    cat -- "$RECORD_TMP" > "$SELF_RECORD_FILE"
    rm -f -- "$RECORD_TMP"

done
