#!/bin/bash
set -ex

# Positional arguments:
#   $1 - model path; passed to every evaluation as --model_name_or_path and
#        used as the parent directory for each evaluation's output.
#   $2 - optional; pass "--all-evals" to run all evaluations.
MODEL_PATH="${1}"
ALL_EVALS="${2}"

# Root directory of the evaluation data sets, exported to child processes.
BASE_DATA_PATH=./data/
export BASE_DATA_PATH

# Registry of evaluations: short task name -> Python module that is launched
# with `python -m` for that task.
declare -A TASKS=()
TASKS[mmlu]="minimal_multitask.eval.mmlu.run_mmlu_eval"
TASKS[gsm]="minimal_multitask.eval.gsm.run_eval"
TASKS[bbh]="minimal_multitask.eval.bbh.run_eval"
TASKS[tydiqa]="minimal_multitask.eval.tydiqa.run_eval"
TASKS[codex]="minimal_multitask.eval.codex_humaneval.run_eval"
TASKS[squad]="minimal_multitask.eval.squad.run_squad_eval"
# Disabled for now:
# TASKS[alpaca]="minimal_multitask.eval.alpaca_eval.run_alpaca_eval"

# Choose the prompt-formatting function from the model path: paths that
# contain "qwen" or "llama3" (case-sensitive substring match) use the
# Hugging Face tokenizer template; every other path uses the Tulu chat format.
case "$MODEL_PATH" in
    *qwen* | *llama3*)
        CHAT_FORMAT_FUNCTION="minimal_multitask.eval.templates.create_prompt_with_huggingface_tokenizer_template"
        ;;
    *)
        CHAT_FORMAT_FUNCTION="minimal_multitask.eval.templates.create_prompt_with_tulu_chat_format"
        ;;
esac

# Decide whether every registered evaluation should run. RUN_ALL holds the
# literal string "true" only when the second argument is exactly
# "--all-evals"; otherwise it holds "false".
if [[ "$ALL_EVALS" != "--all-evals" ]]; then
    RUN_ALL=false
else
    RUN_ALL=true
fi

# Run every registered evaluation that applies to this model.
#
# A task is selected when RUN_ALL is "true" or when the task's short name
# occurs as a substring of MODEL_PATH. Each selected task is launched as
# `python -m <module>` with task-specific arguments and writes its results
# into a directory under MODEL_PATH.
#
# Notes:
#   - All expansions are quoted so that paths containing spaces or glob
#     characters reach Python as single, unmodified arguments.
#   - The script runs under `set -e`, so the first failing evaluation aborts
#     the remaining ones.
#   - Bash does not guarantee any particular order when iterating over the
#     keys of an associative array, so the run order may vary.
RAN_ANY=false
for TASK in "${!TASKS[@]}"; do
    if [[ "$RUN_ALL" == true || "$MODEL_PATH" == *"$TASK"* ]]; then
        case "$TASK" in
            mmlu)
                RAN_ANY=true
                python -m "${TASKS[$TASK]}" \
                    --ntrain 0 \
                    --data_dir "${BASE_DATA_PATH}/eval/mmlu/" \
                    --save_dir "${MODEL_PATH}/eval_mmlu" \
                    --model_name_or_path "${MODEL_PATH}" \
                    --eval_batch_size 1 \
                    --use_chat_format \
                    --chat_formatting_function "${CHAT_FORMAT_FUNCTION}"
                ;;
            gsm)
                RAN_ANY=true
                python -m "${TASKS[$TASK]}" \
                    --data_dir "${BASE_DATA_PATH}/eval/gsm/" \
                    --save_dir "${MODEL_PATH}/eval_gsm" \
                    --model_name_or_path "${MODEL_PATH}" \
                    --n_shot 8 \
                    --use_chat_format \
                    --chat_formatting_function "${CHAT_FORMAT_FUNCTION}" \
                    --use_vllm
                ;;
            bbh)
                RAN_ANY=true
                python -m "${TASKS[$TASK]}" \
                    --data_dir "${BASE_DATA_PATH}/eval/bbh" \
                    --save_dir "${MODEL_PATH}/eval_bbh" \
                    --model_name_or_path "${MODEL_PATH}" \
                    --use_vllm \
                    --use_chat_format \
                    --chat_formatting_function "${CHAT_FORMAT_FUNCTION}"
                ;;
            tydiqa)
                RAN_ANY=true
                python -m "${TASKS[$TASK]}" \
                    --data_dir "${BASE_DATA_PATH}/eval/tydiqa/" \
                    --n_shot 1 \
                    --max_context_length 512 \
                    --save_dir "${MODEL_PATH}/eval_tydiqa" \
                    --model_name_or_path "${MODEL_PATH}" \
                    --eval_batch_size 20 \
                    --use_vllm \
                    --use_chat_format \
                    --chat_formatting_function "${CHAT_FORMAT_FUNCTION}"
                ;;
            codex)
                RAN_ANY=true
                python -m "${TASKS[$TASK]}" \
                    --data_file "${BASE_DATA_PATH}/eval/codex_humaneval/HumanEval.jsonl.gz" \
                    --data_file_hep "${BASE_DATA_PATH}/eval/codex_humaneval/humanevalpack.jsonl" \
                    --use_chat_format \
                    --chat_formatting_function "${CHAT_FORMAT_FUNCTION}" \
                    --eval_pass_at_ks 10 \
                    --unbiased_sampling_size_n 10 \
                    --temperature 0.8 \
                    --save_dir "${MODEL_PATH}/eval_humaneval" \
                    --model_name_or_path "${MODEL_PATH}" \
                    --use_vllm
                ;;
            squad)
                RAN_ANY=true
                # The output directory is created up front because this
                # runner is given explicit file paths rather than a save dir.
                # NOTE(review): unlike the other runners, --use_chat_format is
                # not passed here; confirm that this is intentional.
                mkdir -p "${MODEL_PATH}/eval_squad"
                python -m "${TASKS[$TASK]}" \
                    --model_name_or_path "${MODEL_PATH}" \
                    --output_file "${MODEL_PATH}/eval_squad/predictions.json" \
                    --chat_formatting_function "${CHAT_FORMAT_FUNCTION}" \
                    --metrics_file "${MODEL_PATH}/eval_squad/metrics.json" \
                    --generation_file "${MODEL_PATH}/eval_squad/generation.json" \
                    --use_vllm
                ;;
            alpaca)
                # Only reachable once an "alpaca" entry is enabled in TASKS.
                RAN_ANY=true
                python -m "${TASKS[$TASK]}" \
                    --save_dir "${MODEL_PATH}/eval_alpaca" \
                    --chat_formatting_function "${CHAT_FORMAT_FUNCTION}" \
                    --model_name_or_path "${MODEL_PATH}" \
                    --use_vllm
                ;;
            *)
                # A task was registered in TASKS without a matching runner
                # above; report it instead of skipping it silently.
                printf 'Warning: no runner defined for task "%s"; skipping.\n' "$TASK" >&2
                ;;
        esac
    fi
done

# Make a silent no-op visible: nothing ran because "--all-evals" was not
# given and no task name occurs in the model path.
if [[ "$RAN_ANY" != true ]]; then
    printf 'Warning: no evaluation was selected for model path "%s"; pass --all-evals to run every evaluation.\n' "$MODEL_PATH" >&2
fi

echo "Done evaluation!"
