#!/bin/bash
#
# Run chunked "diverse domain" inference for one checkpoint on Slurm, merge the
# per-chunk answer files, and score the merged file with ming.eval.eval_em.
#
# Usage: <this-script> <checkpoint-name>
#   <checkpoint-name> is a directory under /mnt/petrelfs/usr/checkpoints/.
#   Names containing "llama7b" / "llama8b" select the Llama bases; any other
#   name must contain "ming<SIZE>b" (e.g. ming1.8b, ming7b), and SIZE selects
#   the Qwen1.5-<SIZE>B-Chat base model.

set -u

# gpu_list="${CUDA_VISIBLE_DEVICES:-0}"
# IFS=',' read -ra GPULIST <<< "$gpu_list"

# CHUNKS=${#GPULIST[@]}
CHUNKS=3
# Without an argument MODEL_PATH would be the bare checkpoints directory and the
# wait loop below would never finish, so fail fast instead.
CKPT="${1:?Usage: $0 <checkpoint-name>}"
# Extract SIZE from the "ming{SIZE}b" part of the checkpoint name (e.g. 1.8 or 7).
# SIZE stays empty when there is no match; only the Qwen branch needs it.
SIZE=$(grep -oP 'ming\K[0-9.]+(?=b)' <<< "$CKPT")

domain="math"

if [[ $CKPT == *"llama7b"* ]]; then
    MODEL_BASE=/mnt/petrelfs/usr/models/llama2_7b_chat
    conv_mode="llama2"
elif [[ $CKPT == *"llama8b"* ]]; then
    MODEL_BASE=/mnt/petrelfs/usr/models/Meta-Llama-3-8B-Instruct
    conv_mode="llama3"
else
    # An empty SIZE would silently produce the nonexistent "Qwen1.5-B-Chat" path.
    if [[ -z "$SIZE" ]]; then
        echo "error: cannot infer model size from checkpoint name '${CKPT}' (expected ming<SIZE>b)" >&2
        exit 1
    fi
    MODEL_BASE=/mnt/petrelfs/usr/models/models--Qwen--Qwen1.5-${SIZE}B-Chat
    conv_mode="qwen"
fi

DATASET_PATH=s3://bucket/datasets/diverse_domain/test/${domain}.json
LOGS_BASE_PATH="./logs/diverse"

MODEL_PATH=/mnt/petrelfs/usr/checkpoints/${CKPT}

# Poll once a minute until the checkpoint's adapter_config.json exists.
# NOTE(review): there is no timeout; a misspelled checkpoint name waits forever.
while [[ ! -f "${MODEL_PATH}/adapter_config.json" ]]; do
    echo "Waiting for ${MODEL_PATH}/adapter_config.json to appear..."
    sleep 60
done

version="woffn"
# All local logs and answer files for this run live here.
OUT_DIR="${LOGS_BASE_PATH}/${domain}/${CKPT}-${version}"
mkdir -p "$OUT_DIR"

# Launch one single-GPU srun job per chunk in the background, remembering each
# PID so that failures can be detected individually afterwards.
# NOTE(review): the S3 path uses "${CKPT}_${version}" while the local path uses
# "${CKPT}-${version}"; kept as-is in case consumers depend on it — confirm.
pids=()
for (( IDX = 0; IDX < CHUNKS; IDX++ )); do
    srun -p partition --quotatype=auto --gres=gpu:1 \
        -o "${OUT_DIR}/${CHUNKS}_${IDX}.infer.log" \
        python -m ming.eval.model_diverse_gen \
        --model-path "$MODEL_PATH" \
        --model-base "$MODEL_BASE" \
        --question-file "$DATASET_PATH" \
        --answers-file "${OUT_DIR}/${CHUNKS}_${IDX}.jsonl" \
        --s3-answers-file "s3://bucket/logs/diverse/${domain}/${CKPT}_${version}/${CHUNKS}_${IDX}.jsonl" \
        --num-chunks "$CHUNKS" \
        --chunk-idx "$IDX" \
        --temperature 0 \
        --conv-mode "$conv_mode" \
        --keep-local \
        --use-logit-bias \
        --infer-answer \
        --unload-ffn \
        --resume &
    pids+=("$!")
    sleep 1
done

# A bare `wait` always returns 0. Waiting on each PID instead surfaces failed
# chunks, so the evaluation never silently scores partial results.
failed=0
for i in "${!pids[@]}"; do
    if ! wait "${pids[$i]}"; then
        echo "error: inference chunk ${i} failed; see ${OUT_DIR}/${CHUNKS}_${i}.infer.log" >&2
        failed=1
    fi
done
if (( failed )); then
    exit 1
fi

output_file="${OUT_DIR}/merge.jsonl"

# Truncate the merged file, then append every chunk in index order.
: > "$output_file"
for (( IDX = 0; IDX < CHUNKS; IDX++ )); do
    chunk_file="${OUT_DIR}/${CHUNKS}_${IDX}.jsonl"
    if [[ ! -f "$chunk_file" ]]; then
        echo "error: missing chunk output ${chunk_file}" >&2
        exit 1
    fi
    cat -- "$chunk_file" >> "$output_file"
done

# Evaluate the merged answers; the script's exit status is this job's status.
srun -p partition -o "${OUT_DIR}/eval.log" python -m ming.eval.eval_em \
    --input_file "$output_file"

