#!/bin/bash

# user config
# number of GPUs, passed to vLLM as tensor_parallel_size
num_gpu=1
# conda environment used to run save_merged_cp.py
conda_env_eva=
# conda environment used to run lm_eval
conda_env_lm_eval=
# directory containing the .bashrc that is sourced before conda is used
bashrc_dir=
# directory containing save_merged_cp.py; the script cd's into it
work_dir=
# root directory below which the fine-tuned runs are stored
data_dir=
# model to evaluate; options:
# meta-llama/Llama-2-7b-hf
# meta-llama/Meta-Llama-3.1-8B
# google/gemma-2-9b
# mistralai/Mistral-7B-v0.3
model_name=meta-llama/Llama-2-7b-hf
# last path component of model_name (e.g. Llama-2-7b-hf)
model_folder=${model_name##*/}
# dataset sub-folder below data_dir; options:
# hellaswag
# winogrande
# ai2_arc_challenge
# ai2_arc_easy
# piqa
# social_i_qa
# openbookqa
# boolq
# qa_datasets
dataset_folder=qa_datasets
# comma-separated lm_eval task list; options:
# mmlu
# hellaswag
# winogrande
# arc_challenge
# arc_easy
# piqa
# social_iqa
# openbookqa
# boolq
# custom_hellaswag,custom_winogrande,custom_arc_challenge,custom_arc_easy,custom_piqa,custom_siqa,custom_openbookqa,custom_boolq
tasks=custom_hellaswag,custom_winogrande,custom_arc_challenge,custom_arc_easy,custom_piqa,custom_siqa,custom_openbookqa,custom_boolq
# uncomment to pass --num_fewshot to lm_eval and tag the results directory
#num_fewshot=5
# results directory of the un-finetuned baseline model; the baseline run is
# skipped when this is empty or the directory already exists
baseline_output_path=


# directory tree that is searched for finished training runs
base_dir="${data_dir}/${dataset_folder}/${model_folder}"
# marker file identifying a training run directory
trainer_state_file=trainer_state.json
# task list with commas replaced by underscores, safe to use in a directory
# name; parameter expansion avoids word splitting/globbing of $tasks
tasks_safe=${tasks//,/_}
# name of the per-run results sub-directory
dirname="tasks_${tasks_safe}"

# Load the user's shell configuration; presumably this is what defines the
# `conda` shell function in a non-interactive shell (verify for your setup).
# shellcheck disable=SC1091
source "$bashrc_dir/.bashrc"
# start from a clean state with no conda environment active
conda deactivate

# Fail fast instead of silently continuing in $HOME (empty work_dir) or in
# whatever directory the script was started from (cd failure).
cd "${work_dir:?work_dir must be set in the user config}" || exit 1

# Few-shot handling: by default no extra flag is passed to lm_eval. When
# num_fewshot is non-empty, build the flag and tag the results directory name
# with the shot count.
num_fewshot_arg=""
if [ "${num_fewshot:+set}" = "set" ]; then
    num_fewshot_arg="--num_fewshot ${num_fewshot}"
    dirname+="_fewshot_${num_fewshot}"
fi

# Print, NUL-delimited, every directory at or below a root that holds a
# training run which has not been evaluated yet.
#   $1 - root directory to search
#   $2 - name of the marker file a run directory must contain
#   $3 - name of the results sub-directory; runs that have it are skipped
# Directories whose path *below the root* contains "checkpoint-" are skipped.
# Only the part below the root is tested, so a root that itself contains
# "checkpoint-" somewhere in its path does not exclude every run.
find_unevaluated_runs() {
    local root=$1 state_file=$2 results_name=$3
    local dir rel
    while IFS= read -r -d '' dir; do
        rel=${dir#"$root"}
        [[ -f "$dir/$state_file" ]] || continue
        [[ "$rel" == *checkpoint-* ]] && continue
        [[ -d "$dir/$results_name" ]] && continue
        printf '%s\0' "$dir"
    done < <(find "$root" -type d -print0)
}

# Collect the runs to evaluate: directories containing trainer_state_file that
# are not checkpoint directories and have no results directory yet.
matching_dirs=()
while IFS= read -r -d '' dir; do
    matching_dirs+=("$dir")
done < <(find_unevaluated_runs "$base_dir" "$trainer_state_file" "$dirname")

echo "matching directories found: ${matching_dirs[*]}"

# Generate a random 32-character hexadecimal hash. It names the temporary
# merged-checkpoint sub-directory inside each run directory, which is removed
# with `rm -rf` after evaluation. An empty value would make that path equal
# to the run directory itself, so the value is validated before continuing.
hex32_re='^[0-9a-f]{32}$'
random_hash=$(openssl rand -hex 16)
if [[ ! "$random_hash" =~ $hex32_re ]]; then
    # openssl missing or failed: read 16 random bytes from the kernel instead
    random_hash=$(od -An -N16 -tx1 /dev/urandom | tr -d ' \n')
fi
if [[ ! "$random_hash" =~ $hex32_re ]]; then
    echo "error: could not generate a random directory name" >&2
    exit 1
fi

# Run lm_eval with the vLLM backend on one model and store the results.
#   $1 - model name or path, passed as `pretrained=`
#   $2 - directory passed to lm_eval as --output_path
# Reads: conda_env_lm_eval, num_gpu, tasks, num_fewshot_arg
# Returns: the exit status of lm_eval
run_lm_eval() {
    local pretrained=$1 output_path=$2
    local status=0
    conda activate "$conda_env_lm_eval"
    # $num_fewshot_arg is intentionally unquoted: it is either empty or
    # "--num_fewshot N" and has to split into two words.
    # shellcheck disable=SC2086
    lm_eval \
        --model vllm \
        --model_args "pretrained=${pretrained},trust_remote_code=True,tensor_parallel_size=${num_gpu}" \
        --tasks "$tasks" \
        --output_path "$output_path" \
        $num_fewshot_arg || status=$?
    conda deactivate
    return "$status"
}

# Merge one fine-tuned run into a temporary checkpoint, evaluate it, and
# remove the temporary checkpoint again.
#   $1 - run directory
# Reads: random_hash, conda_env_eva, work_dir, dirname
# Returns: non-zero if the merge or the evaluation failed
evaluate_run() {
    local run_dir=$1
    local merged_dir status=0
    # With an empty hash the "temporary" directory would be the run directory
    # itself, and the cleanup below would delete the whole run.
    if [[ -z "$random_hash" ]]; then
        echo "error: random_hash is empty; refusing to process $run_dir" >&2
        return 1
    fi
    # save merged checkpoint to a random subdirectory
    merged_dir="$run_dir/$random_hash"
    conda activate "$conda_env_eva"
    python3 "$work_dir/save_merged_cp.py" \
        --cp_path "$run_dir" \
        --dst_path "$merged_dir" \
        --device cpu || status=$?
    conda deactivate
    # run eval only on a successfully merged checkpoint
    if (( status == 0 )); then
        run_lm_eval "$merged_dir" "$run_dir/$dirname" || status=$?
    else
        echo "error: merging $run_dir failed (exit $status); skipping evaluation" >&2
    fi
    # remove temp directory, also after a failure
    rm -rf -- "${merged_dir:?}"
    return "$status"
}

# evaluate baseline; skipped when no output path is configured or when the
# results directory already exists
if [[ -n "$baseline_output_path" && ! -d "$baseline_output_path" ]]; then
    run_lm_eval "$model_name" "$baseline_output_path"
fi

for path in "${matching_dirs[@]}"; do
    evaluate_run "$path"
done