#!/bin/bash

# Create the log directory if needed, then create or truncate both log files.
# If the directory were missing, every later redirection into these logs would
# fail and the commands behind them would not run, so abort early instead.
mkdir -p scripts || exit 1
: > scripts/infer_para_sample_stdout2.log || exit 1
: > scripts/infer_para_sample_stderr2.log || exit 1

# k values to sweep; run_inference passes each one to infer_llama.py as --k.
# The commented-out lines are alternative sweeps; keep exactly one active.
# k_values=(8 16 32 64)
# k_values=(1 2 4 128 256 512 1024 2000)
k_values=(32 64 128 256 512 768 1024 2048) ## all 2048
# k_values=(64 128 256 512 1024 1536 2048 4096) ## all 4096

# GPU IDs to use; the i-th index range runs on the i-th GPU ID.
# gpu_ids=(0 1 2 3)
gpu_ids=(4 5 6 7)

# Index ranges, one "START END" pair per entry; the two numbers are passed to
# infer_llama.py as --start_idx and --end_idx.
# Alternative set of ranges:
# index_ranges=(
#     "0 200"
#     "200 400"
#     "400 600"
#     "600 800"
# )

index_ranges=(
    "0 100"
    "200 300"
    "400 500"
    "600 700"
)


# Every index range needs its own GPU ID, so refuse to start when there are
# fewer GPU IDs than ranges. The diagnostic goes to stderr so it is not mixed
# into captured stdout.
if (( ${#gpu_ids[@]} < ${#index_ranges[@]} )); then
    echo "Error: Not enough GPUs specified. Found ${#gpu_ids[@]}, need ${#index_ranges[@]}." >&2
    exit 1
fi

# Run the full inference sweep for one index range with one GPU ID.
#
# Runs `infer_llama.py --naive` once, then `infer_llama.py --k K` once for
# every K in k_values, all with CUDA_VISIBLE_DEVICES set to the given GPU ID.
# stdout and stderr of every python run are appended to the log files under
# scripts/. A failed run does not stop the sweep; it is reported and reflected
# in the return status.
#
# Globals:   k_values (read)
# Arguments: $1 - GPU ID, used as the value of CUDA_VISIBLE_DEVICES
#            $2 - value passed as --start_idx
#            $3 - value passed as --end_idx
# Outputs:   progress lines on stdout; one warning on stderr per failed run
# Returns:   0 if every python run succeeded, 1 otherwise
run_inference() {
    local gpu=$1
    local start_idx=$2
    local end_idx=$3
    local k
    local status=0
    local stdout_log=scripts/infer_para_sample_stdout2.log
    local stderr_log=scripts/infer_para_sample_stderr2.log
    # `local -x` exports the variable to the python processes but restores the
    # caller's value on return, so the setting cannot leak when the function
    # is called outside a background subshell.
    local -x CUDA_VISIBLE_DEVICES=$gpu

    if ! python infer_llama.py \
        --naive \
        --start_idx "$start_idx" \
        --end_idx "$end_idx" \
        >> "$stdout_log" 2>> "$stderr_log"; then
        echo "GPU $gpu: naive run failed, range = $start_idx to $end_idx" >&2
        status=1
    fi

    for k in "${k_values[@]}"; do
        echo "GPU $gpu: k = $k, range = $start_idx to $end_idx"
        if ! python infer_llama.py \
            --k "$k" \
            --start_idx "$start_idx" \
            --end_idx "$end_idx" \
            >> "$stdout_log" 2>> "$stderr_log"; then
            echo "GPU $gpu: k = $k run failed, range = $start_idx to $end_idx" >&2
            status=1
        fi
    done

    echo "GPU $gpu completed"
    return "$status"
}

# Start one background run_inference job per index range, giving the i-th
# range the i-th GPU ID, then wait for every job to finish.
#
# Globals:   index_ranges (read), gpu_ids (read)
# Outputs:   one error line on stderr for each job that exits non-zero
# Returns:   0 if every job exited with status 0, 1 otherwise
launch_all_ranges() {
    local i start_idx end_idx
    local failed=0
    local -a pids=()

    for i in "${!index_ranges[@]}"; do
        # Split "START END" with read rather than an unquoted expansion, which
        # would also subject the string to pathname expansion.
        read -r start_idx end_idx _ <<< "${index_ranges[$i]}"
        run_inference "${gpu_ids[$i]}" "$start_idx" "$end_idx" &
        pids[i]=$!
    done

    # A bare `wait` returns 0 no matter how the jobs ended; waiting on each
    # PID individually exposes that job's exit status.
    for i in "${!pids[@]}"; do
        if ! wait "${pids[$i]}"; then
            echo "Error: inference for range ${index_ranges[$i]} on GPU ${gpu_ids[$i]} failed" >&2
            failed=1
        fi
    done

    return "$failed"
}

if ! launch_all_ranges; then
    echo "Error: one or more inference tasks failed" >&2
    exit 1
fi

echo "All inference tasks completed"
