#!/bin/bash

# gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
# IFS=',' read -ra GPULIST <<< "$gpu_list"

# CHUNKS=${#GPULIST[@]}
# Number of parallel inference shards; each shard is one single-GPU srun job.
# (The commented-out lines above would derive this from CUDA_VISIBLE_DEVICES.)
CHUNKS=3

# Checkpoint directory name under /mnt/petrelfs/usr/checkpoints (required $1).
CKPT="${1:-}"
if [[ -z "$CKPT" ]]; then
    printf 'usage: %s <checkpoint-name>\n' "${0##*/}" >&2
    exit 1
fi

# Extract the model size from the first "ming<SIZE>b" token anywhere in the
# checkpoint name (e.g. "ming1.8b_x" -> "1.8", "ming7b_x" -> "7").
# SIZE stays empty when the name has no such token.
SIZE=""
size_re='ming([0-9.]+)b'
if [[ "$CKPT" =~ $size_re ]]; then
    SIZE="${BASH_REMATCH[1]}"
fi

# The Qwen fallback builds its base-model path from SIZE, so a checkpoint that
# is not a known Llama variant must carry a "ming<SIZE>b" token; otherwise the
# base path would silently become ".../Qwen1.5-B-Chat".
if [[ -z "$SIZE" && "$CKPT" != *llama7b* && "$CKPT" != *llama8b* ]]; then
    printf 'error: cannot determine model size from checkpoint name: %s\n' "$CKPT" >&2
    exit 1
fi

# Pick the base model and conversation template from the checkpoint name.
# Patterns are substring matches checked in order; anything that is not a
# known Llama variant falls back to Qwen1.5 of the parsed SIZE.
case "$CKPT" in
    *llama7b*)
        MODEL_BASE=/mnt/petrelfs/usr/models/llama2_7b_chat
        conv_mode="llama2"
        ;;
    *llama8b*)
        MODEL_BASE=/mnt/petrelfs/usr/models/Meta-Llama-3-8B-Instruct
        conv_mode="llama3"
        ;;
    *)
        MODEL_BASE=/mnt/petrelfs/usr/models/models--Qwen--Qwen1.5-${SIZE}B-Chat
        conv_mode="qwen"
        ;;
esac
# Evaluation questions (BBH) and the root of the local log/output tree.
DATASET_PATH=s3://bucket/datasets/diverse_domain/test2/bbh.json
LOGS_BASE_PATH="./logs/diverse"

# Checkpoint under evaluation.
MODEL_PATH="/mnt/petrelfs/usr/checkpoints/${CKPT}"


# Disabled: block until the checkpoint weights have been written.
# while [ ! -f "${MODEL_PATH}/model.safetensors" ]; do
#     echo "Waiting for ${MODEL_PATH}/model.safetensors to appear..."
#     sleep 60
# done

# Run tag appended to the output directory name.
# NOTE(review): presumably "without FFN", matching the --unload-ffn flag passed
# to the inference jobs — confirm.
version="woffn"
mkdir -p "${LOGS_BASE_PATH}/bbh/${CKPT}_${version}" || exit 1

# Launch one single-GPU inference job per chunk, all in parallel. Job IDX
# handles chunk IDX of CHUNKS and writes <CHUNKS>_<IDX>.jsonl into the run
# directory, plus a copy to S3.
run_dir="${LOGS_BASE_PATH}/bbh/${CKPT}_${version}"
# S3 sub-directory for the answers. It used to be referenced without ever being
# set, which collapsed the S3 path to ".../diverse//...". Defaults to the
# evaluated domain but still honours a caller-provided value.
domain="${domain:-bbh}"
pids=()
for (( IDX = 0; IDX < CHUNKS; IDX++ )); do
    srun -p partition --quotatype=auto --gres=gpu:1 \
        -o "${run_dir}/${CHUNKS}_${IDX}.infer.log" \
        python -m ming.eval.model_diverse_gen \
        --model-path "${MODEL_PATH}" \
        --question-file "$DATASET_PATH" \
        --answers-file "${run_dir}/${CHUNKS}_${IDX}.jsonl" \
        --s3-answers-file "s3://bucket/logs/diverse/${domain}/${CKPT}_${version}_${IDX}.jsonl" \
        --num-chunks "$CHUNKS" \
        --chunk-idx "$IDX" \
        --temperature 0 \
        --conv-mode "${conv_mode}" \
        --max-tokens 1024 \
        --keep-local \
        --unload-ffn \
        --resume &
    pids+=("$!")
    # Short pause between job submissions.
    sleep 1
done

# A bare `wait` always succeeds. Wait on each job so a failed shard stops the
# script instead of merging and scoring incomplete results.
failed=0
for pid in "${pids[@]}"; do
    wait "$pid" || failed=1
done
if (( failed )); then
    printf 'error: one or more inference chunks failed; see %s/*.infer.log\n' "$run_dir" >&2
    exit 1
fi

# Concatenate the per-chunk answer files, in chunk order, into merge.jsonl.
output_file="${LOGS_BASE_PATH}/bbh/${CKPT}_${version}/merge.jsonl"

# Clear out any previous merge so a stale file is never scored.
: > "$output_file" || exit 1

chunk_files=()
for (( IDX = 0; IDX < CHUNKS; IDX++ )); do
    chunk_files+=("${LOGS_BASE_PATH}/bbh/${CKPT}_${version}/${CHUNKS}_${IDX}.jsonl")
done

# A missing chunk previously only produced a cat error, and the partial merge
# was still evaluated. Abort instead.
for chunk in "${chunk_files[@]}"; do
    if [[ ! -f "$chunk" ]]; then
        printf 'error: missing chunk output %s\n' "$chunk" >&2
        exit 1
    fi
done

cat -- "${chunk_files[@]}" > "$output_file" || exit 1

# Evaluate the merged answers with ming.eval.eval_em. The job's output goes to
# eval.log in the run directory.
# NOTE(review): "em" presumably means exact-match scoring — confirm.
srun -p partition -o "${LOGS_BASE_PATH}/bbh/${CKPT}_${version}/eval.log" \
    python -m ming.eval.eval_em \
    --input_file "${LOGS_BASE_PATH}/bbh/${CKPT}_${version}/merge.jsonl"


