#!/bin/bash
# GPU selection: enumerate devices in PCI bus order and expose only device 0
# to every process started from this script.
export CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0

# Logging: from here on, everything this script (and its children) writes to
# stdout or stderr goes to $log_file, which is truncated on each start.
log_dir=./results/log
log_file="${log_dir}/inference_0.out"
mkdir -p -- "${log_dir}"
exec &>"${log_file}"

# ---------------------------------------------------------------------------
# Sweep configuration
# ---------------------------------------------------------------------------

# Models to evaluate; each entry is forwarded to compute_ppl.py as --model_name.
model_names=(
    gemma-2-2b
    gemma-2-9b
    llama-3.1-8b
    llama-3.2-1b
    llama-3.2-3b
)

# Datasets forwarded as --train_dataset / --test_dataset.
train_datasets=(hellaswag)
test_datasets=(hellaswag)

# Scalar settings.
subset_size=100             # forwarded as --subset_size
exp_num=10                  # seeds 0..exp_num-1 for methods that are repeated
metric=cosine_similarity    # forwarded as --metric

# Number of shots (k) to sweep over; forwarded as --k.
ks=(4 8)

# Methods to sweep over; forwarded as --method.
methods=(knn_diversity diversity random knn k_means)
# Alternative: restrict the sweep to a single method.
# methods=("knn")

# NOTE(review): dp_choices is not referenced by the loops in this file.
dp_choices=(knn)

# Embedding models to sweep over; forwarded as --emb.
embs=(all-roberta-large-v1)


# Pre-compute how many inference runs the sweep consists of, so that progress
# can be reported as "current / total".
#
# count_total_runs walks the same nesting as the sweep (model x train/test
# dataset pair x embedding x k x method) and adds up the number of runs each
# method contributes.
#   Globals read:    model_names, train_datasets, test_datasets, embs, ks,
#                    methods, exp_num
#   Globals written: total_num
count_total_runs() {
    local model_name train_dataset test_dataset emb k method exp_num_method

    total_num=0
    for model_name in "${model_names[@]}"; do
        for train_dataset in "${train_datasets[@]}"; do
            for test_dataset in "${test_datasets[@]}"; do
                # Only count pairs where test and train dataset are the same.
                # If you need to run OOD (test != train), comment out this
                # check here and the equivalent check in the run loop.
                if [[ "$test_dataset" != "$train_dataset" ]]; then
                    continue
                fi
                for emb in "${embs[@]}"; do
                    for k in "${ks[@]}"; do
                        for method in "${methods[@]}"; do
                            # Methods whose name contains "knn" or "k_means"
                            # get a single run (presumably because they do not
                            # depend on the seed -- confirm); every other
                            # method is repeated exp_num times.
                            if [[ "$method" == *"knn"* || "$method" == *"k_means"* ]]; then
                                exp_num_method=1
                            else
                                exp_num_method=$exp_num
                            fi
                            total_num=$((total_num + exp_num_method))
                        done
                    done
                done
            done
        done
    done
}

count_total_runs
echo "Total number of runs: $total_num"


# Run the inference sweep, starting from a given run index.
#
# Every (model, dataset pair, embedding, k, method, seed) combination is
# numbered from 1 in current_num. compute_ppl.py is only invoked for runs
# whose number is >= target_num; earlier runs are counted and logged as
# skipped, which allows a sweep to be picked up from a given index.
#
# target_num defaults to 209 and may be overridden with the script's first
# argument (pass 1 to execute every run).
target_num=${1:-209}
if ! [[ "$target_num" =~ ^[0-9]+$ ]]; then
    echo "target_num must be a non-negative integer, got '$target_num'" >&2
    exit 2
fi

# run_inference_sweep: iterate over the whole configuration grid and launch
# compute_ppl.py for every run at or beyond target_num.
#   Globals read:    model_names, train_datasets, test_datasets, embs, ks,
#                    methods, exp_num, subset_size, metric, total_num,
#                    target_num
#   Globals written: current_num (number of the last run visited)
run_inference_sweep() {
    local model_name train_dataset test_dataset emb k method
    local exp_num_method i

    current_num=0
    for model_name in "${model_names[@]}"; do
        for train_dataset in "${train_datasets[@]}"; do
            for test_dataset in "${test_datasets[@]}"; do
                # Only run pairs where test and train dataset are the same.
                # Comment out this check to run OOD (test != train) pairs.
                if [[ "$test_dataset" != "$train_dataset" ]]; then
                    continue
                fi
                for emb in "${embs[@]}"; do
                    for k in "${ks[@]}"; do
                        for method in "${methods[@]}"; do
                            # Methods whose name contains "knn" or "k_means"
                            # get a single run; all others are repeated
                            # exp_num times, with the loop index as the seed.
                            if [[ "$method" == *"knn"* || "$method" == *"k_means"* ]]; then
                                exp_num_method=1
                            else
                                exp_num_method=$exp_num
                            fi
                            for ((i = 0; i < exp_num_method; i++)); do
                                current_num=$((current_num + 1))
                                echo "Condition: $current_num / $total_num"
                                if [ "$current_num" -ge "$target_num" ]; then
                                    echo "Running inference for model $model_name on train_dataset=$train_dataset, test_dataset=$test_dataset with embedding $emb, seed=$i, k=$k, method=$method"
                                    # A failed run is reported but does not
                                    # stop the remaining runs of the sweep.
                                    python compute_ppl.py \
                                        --model_name "$model_name" \
                                        --train_dataset "$train_dataset" \
                                        --test_dataset "$test_dataset" \
                                        --seed "$i" \
                                        --batch_size 4 \
                                        --subset_size "$subset_size" \
                                        --k "$k" \
                                        --method "$method" \
                                        --emb "$emb" \
                                        --metric "$metric" \
                                        || echo "WARNING: compute_ppl.py exited with status $? (run $current_num)" >&2
                                else
                                    echo "Skipping run $current_num (below target_num=$target_num)"
                                fi
                            done
                        done
                    done
                done
            done
        done
    done
}

run_inference_sweep

