#!/bin/bash
# Enumerate GPUs by PCI bus id and make only device index 0 visible to child processes.
export CUDA_DEVICE_ORDER="PCI_BUS_ID" CUDA_VISIBLE_DEVICES="0"

# NOTE: run this script inside the `dataicl` environment for the classification task;
# otherwise bfloat16 reports an error.

# From here on, everything this script writes to stdout or stderr goes to the log file.
log_dir=./results/log
log_file="${log_dir}/inference_0.out"
mkdir -p -- "$log_dir"
exec &>"$log_file"


# Test and train datasets to sweep over (every train/test combination is visited).
test_dataset_names=(
    "glue-sst2"
)
train_dataset_names=(
    "amazon_polarity"
    "imdb"
)

# Models to sweep over.
model_names=(
    "llama-3.1-8b"
    "gemma-2-9b"
)

# Number of repetitions for methods that are run more than once.
exp_num=10
metric=cosine_similarity
# Values of k to sweep over.
ks=(
    4
    8
)

# Methods to run. A wider selection, currently disabled:
# methods=("knn" "random" "diversity" "knn_diversity" "k_means")
methods=(
    "knn"
    "diversity"
    "knn_diversity"
)

# NOTE(review): dp_choices is not referenced in the visible part of this script.
dp_choices=(
    "knn"
)
# Embedding model names to sweep over.
embs=(
    "all-roberta-large-v1"
)
# Default subset size handed to each run.
subset_size=100

#######################################
# Print how many repetitions a selection method gets.
# Any method whose name contains "knn" (this also covers "knn_diversity"),
# "voke_k" or "k_means" is run once; every other method is run exp_num times.
# NOTE(review): "voke_k" may be a typo for "vote_k" -- kept unchanged; confirm.
# Globals:   exp_num (read)
# Arguments: $1 - method name
# Outputs:   repetition count on stdout
#######################################
runs_for_method() {
    case "$1" in
        *knn* | *voke_k* | *k_means*) printf '%s\n' 1 ;;
        *) printf '%s\n' "$exp_num" ;;
    esac
}

#######################################
# Print the subset size to use for one run of a method.
# A method named "diversity_<N>" carries its own size N; every other method
# uses the global default. The global subset_size is never modified, so a size
# embedded in one method name cannot leak into the runs of later methods.
# Globals:   subset_size (read)
# Arguments: $1 - method name
# Outputs:   subset size on stdout
# Returns:   0 if the size came from the method name, 1 if the default was used
#######################################
subset_size_for_method() {
    if [[ "$1" =~ ^diversity_([0-9]+)$ ]]; then
        printf '%s\n' "${BASH_REMATCH[1]}"
        return 0
    fi
    printf '%s\n' "$subset_size"
    return 1
}

# Pass 1: count the runs of the whole sweep so progress can be shown as "current / total".
total_num=0
for k in "${ks[@]}"; do
    for model_name in "${model_names[@]}"; do
        for train_dataset in "${train_dataset_names[@]}"; do
            for test_dataset in "${test_dataset_names[@]}"; do
                for emb in "${embs[@]}"; do
                    for method in "${methods[@]}"; do
                        # the following code is used to skip the combinations where train_dataset differs from test_dataset
                        # (if enabled, enable the same check in the run loop below so the total stays correct)
                        # if [ "$train_dataset" != "$test_dataset" ]; then
                            # continue
                        # fi
                        exp_num_method=$(runs_for_method "$method")
                        total_num=$((total_num + exp_num_method))
                    done
                done
            done
        done
    done
done

echo "Total number of runs: $total_num"

# Resume point: runs whose 1-based index is below target_num are skipped; -1 runs everything.
target_num=-1

batch_size=1

# Pass 2: execute the runs in the same order in which they were counted.
current_num=0
for k in "${ks[@]}"; do
    for model_name in "${model_names[@]}"; do
        for train_dataset in "${train_dataset_names[@]}"; do
            for test_dataset in "${test_dataset_names[@]}"; do
                for emb in "${embs[@]}"; do
                    for method in "${methods[@]}"; do
                        # if [ "$train_dataset" != "$test_dataset" ]; then
                            # continue
                        # fi
                        exp_num_method=$(runs_for_method "$method")
                        # The repetition index i doubles as the seed of the run.
                        for ((i = 0; i < exp_num_method; i++)); do
                            current_num=$((current_num + 1))
                            if ((current_num < target_num)); then
                                continue
                            fi
                            # Per-run value: the global default is left untouched.
                            if run_subset_size=$(subset_size_for_method "$method"); then
                                echo "extract the subset_size from method name: $run_subset_size"
                            fi
                            echo "Current running state: $current_num / $total_num"
                            echo "Running inference for model $model_name on dataset $test_dataset with embedding $emb, seed=$i, k=$k, method=$method, batch_size=$batch_size"
                            # A failed run is reported but does not abort the remaining sweep.
                            python compute_ppl.py \
                                --model_name "$model_name" \
                                --train_dataset "$train_dataset" \
                                --test_dataset "$test_dataset" \
                                --seed "$i" \
                                --subset_size "$run_subset_size" \
                                --k "$k" \
                                --method "$method" \
                                --emb "$emb" \
                                --metric "$metric" \
                                --batch_size "$batch_size" ||
                                echo "WARNING: run $current_num / $total_num exited with status $?" >&2
                        done
                    done
                done
            done
        done
    done
done
