#!/bin/bash
#
# Inference sweep launcher: configures the GPU environment, sends all output
# to a log file, defines the sweep settings, and then calls fast_inference.py
# for each configured combination.

# Device selection: enumerate GPUs by PCI bus ID and expose only GPU 0.
export CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0

# From here on, stdout and stderr of this script -- inherited by every command
# it launches -- go to the log file. The directory is created first.
log_dir="./results/log"
log_file="${log_dir}/inference_0.out"
mkdir -p "$log_dir"
exec >"$log_file" 2>&1

# UTF-8 for Python's standard streams and for the locale.
export PYTHONIOENCODING=utf-8 LANG=zh_CN.UTF-8

# Sweep axes (one entry per value to try).
modelabbrs=("llama-3.1-8b")
train_datasets=("squad" "sciq")
test_datasets=("squad" "sciq")
ks=(4)
methods=("k_means")
embs=("all-roberta-large-v1")
permutations=(1)

# Scalar settings.
subset_size=100
exp_num=10
metric="cosine_similarity"
freq=32
max_new_tokens=1024
decoding="greedy"

#######################################
# Count the fast_inference.py invocations the sweep performs, so progress can
# be reported as "current / total".
# The total is a product of independent factors:
#   matched train/test dataset pairs (only train == test is run)
#   x models x embedding models x k values
#   x runs per method, summed over methods (1 for names containing "knn" or
#     "k_means", exp_num otherwise)
#   x permutations (each repetition is launched once per permutation).
# Globals (read): train_datasets, test_datasets, modelabbrs, embs, ks,
#                 methods, permutations, exp_num
# Outputs: the total number of runs on stdout.
#######################################
count_total_runs() {
    local train_dataset test_dataset method total
    local matched_pairs=0 runs_per_combo=0

    for train_dataset in "${train_datasets[@]}"; do
        for test_dataset in "${test_datasets[@]}"; do
            if [[ "$train_dataset" == "$test_dataset" ]]; then
                matched_pairs=$((matched_pairs + 1))
            fi
        done
    done

    for method in "${methods[@]}"; do
        if [[ "$method" == *"knn"* || "$method" == *"k_means"* ]]; then
            runs_per_combo=$((runs_per_combo + 1))
        else
            runs_per_combo=$((runs_per_combo + exp_num))
        fi
    done

    total=$((matched_pairs * ${#modelabbrs[@]} * ${#embs[@]} * ${#ks[@]}))
    # Permutations multiply the run count too; leaving them out makes the
    # progress counter overshoot its total whenever there is more than one.
    total=$((total * runs_per_combo * ${#permutations[@]}))
    printf '%d\n' "$total"
}

total_num=$(count_total_runs)
echo "Total number of runs: $total_num"
                        
# Runs are numbered from 1 in sweep order; any run whose number is below
# target_num is counted but not executed (presumably so an interrupted sweep
# can be resumed). -1 executes every run.
target_num=-1

#######################################
# Resolve the --subset_size for a single method.
# A method named "diversity_<N>" carries its own subset size N; every other
# method uses the supplied default.
# Arguments: $1 - method name
#            $2 - default subset size
# Outputs:   the subset size on stdout; a note on stderr when the size was
#            taken from the method name.
#######################################
subset_size_for_method() {
    local method=$1 default_size=$2
    if [[ "$method" =~ ^diversity_([0-9]+)$ ]]; then
        echo "Extracted subset_size from method name $method: ${BASH_REMATCH[1]}" >&2
        printf '%s\n' "${BASH_REMATCH[1]}"
    else
        printf '%s\n' "$default_size"
    fi
}

#######################################
# Launch fast_inference.py once per run of the sweep: every matched
# train/test dataset pair x model x embedding model x k x method x
# repetition x permutation.
# Globals (read):    train_datasets, test_datasets, modelabbrs, embs, ks,
#                    methods, permutations, exp_num, subset_size, metric,
#                    max_new_tokens, decoding, freq, target_num, total_num
# Globals (written): current_num, DECODING_FLAGS
# Outputs: progress lines and the output of every run; a warning on stderr
#          for each run that exits non-zero (the sweep keeps going).
#######################################
run_sweep() {
    local train_dataset test_dataset modelabbr modelname emb k method
    local permutation exp_num_method run_subset_size i rc
    local decoding_flags_loaded=0
    local -a decoding_flags=()

    current_num=0
    for train_dataset in "${train_datasets[@]}"; do
        for test_dataset in "${test_datasets[@]}"; do
            # Only matched train/test pairs are evaluated.
            if [[ "$train_dataset" != "$test_dataset" ]]; then
                continue
            fi
            for modelabbr in "${modelabbrs[@]}"; do
                modelname="/home/amax/exp/huggingface/transformers/${modelabbr}"
                for emb in "${embs[@]}"; do
                    for k in "${ks[@]}"; do
                        for method in "${methods[@]}"; do
                            # Methods whose name contains "knn" or "k_means"
                            # run once; every other method runs exp_num times.
                            if [[ "$method" == *"knn"* || "$method" == *"k_means"* ]]; then
                                exp_num_method=1
                            else
                                exp_num_method=$exp_num
                            fi
                            # Resolved per method into its own variable: writing
                            # the global subset_size here would leak a
                            # diversity_<N> size into every method after it.
                            run_subset_size=$(subset_size_for_method "$method" "$subset_size")
                            for ((i = 0; i < exp_num_method; i++)); do
                                for permutation in "${permutations[@]}"; do
                                    # Plain assignment: ((x++)) from 0 returns 1.
                                    current_num=$((current_num + 1))
                                    if ((current_num < target_num)); then
                                        continue
                                    fi
                                    echo "Current run status: $current_num / $total_num"
                                    if ((!decoding_flags_loaded)); then
                                        # The flags depend only on $decoding, so the
                                        # helper is invoked once, on the first run.
                                        DECODING_FLAGS=$(python decoding_args_helper.py "$decoding")
                                        # Split on IFS into an array: the same word
                                        # splitting as an unquoted expansion, but
                                        # without pathname (glob) expansion.
                                        read -r -d '' -a decoding_flags <<<"$DECODING_FLAGS" || true
                                        decoding_flags_loaded=1
                                    fi
                                    echo "Decoding Flags: $DECODING_FLAGS"
                                    rc=0
                                    python -u fast_inference.py \
                                        --model_path "$modelname" \
                                        --train_dataset "$train_dataset" \
                                        --test_dataset "$test_dataset" \
                                        --prompt_template_style "$test_dataset" \
                                        --subset_size "$run_subset_size" \
                                        --k "$k" \
                                        --exp_num "$i" \
                                        --method "$method" \
                                        --emb "$emb" \
                                        --metric "$metric" \
                                        --permutation "$permutation" \
                                        --max_new_tokens "$max_new_tokens" \
                                        "${decoding_flags[@]}" --freq "$freq" || rc=$?
                                    # One failed run must not abort the sweep, but
                                    # its status is recorded in the log.
                                    if ((rc != 0)); then
                                        echo "WARNING: run $current_num / $total_num exited with status $rc" >&2
                                    fi
                                done
                            done
                        done
                    done
                done
            done
        done
    done
}

run_sweep

