source scripts/common_setting.sh

# Tag that prefixes the per-run directory name under tune_log/<model_name>/.
tune_ckpt_path="buddy"

# Parallel job tables: the scheduling loop further down reads the same index
# from each of these arrays to assemble one training job.
# NOTE(review): $llama2 is not set in this file; presumably it comes from
# scripts/common_setting.sh -- confirm.
base_models=("$llama2")
model_names=("llama2")
data_paths=("yahma/alpaca-cleaned")
data_names=("alpaca")
lora_rs=(8)

# Values handed to CUDA_VISIBLE_DEVICES, one per concurrently running job.
gpu_ids=(7 1 2 3)

#######################################
# Launch one train_budget_predictor.py run pinned to a single GPU.
# Despite the name, the only command invoked here is the training script.
# Globals:
#   tune_ckpt_path (read) - prefix of the run directory name
# Arguments:
#   $1 - gpu_id:     value assigned to CUDA_VISIBLE_DEVICES for the run
#   $2 - base_model: forwarded as --base_model
#   $3 - model_name: used in the sensitivity CSV path and the output dir
#   $4 - data_path:  forwarded as --data_path
#   $5 - data_name:  forwarded as --data_name and used in the output dir
#   $6 - lora_r:     used only as a suffix of the run directory name
# Outputs:
#   Writes the base model and a start timestamp to stdout.
# Returns:
#   Exit status of the python invocation.
#######################################
run_tuning_and_evaluation(){

  local gpu_id=$1
  local base_model=$2
  local model_name=$3
  local data_path=$4
  local data_name=$5
  local lora_r=$6
  # Kept local so the timestamp does not leak into the caller's scope.
  local current_time

  echo "base_model: ${base_model}"
  current_time=$(date "+%Y-%m-%d %H:%M:%S")
  echo "Start tuning on gpu: $gpu_id, $current_time"

  # Every expansion is quoted so that paths containing whitespace or glob
  # characters reach the script as exactly one argument each.
  CUDA_VISIBLE_DEVICES="$gpu_id" python train_budget_predictor.py \
     --base_model "$base_model" \
     --data_name "$data_name" \
     --data_path "$data_path" \
     --lambda_reg 0.1 \
     --sensitivity_type taylor \
     --sensitivity_path "utils/sensitivity/${model_name}_output/taylor/block_score_all.csv" \
     --output_dir "tune_log/${model_name}/${tune_ckpt_path}_${lora_r}/${data_name}/" \
     --cutoff_len 512 \
     --num_epochs 1 \
     --learning_rate 1e-5 \
     --batch_size 2
}

# Run the configured jobs in batches: each batch starts at most one background
# job per entry in gpu_ids, then blocks until every job of that batch is done.
# Job idx takes element idx from each configuration array; the arrays are
# assumed to be at least as long as data_paths (not checked here).
num_gpus=${#gpu_ids[@]}
num_jobs=${#data_paths[@]}

if (( num_jobs > 0 && num_gpus == 0 )); then
  # A zero stride would never advance the batch loop, so bail out explicitly.
  echo "gpu_ids is empty; cannot schedule ${num_jobs} job(s)" >&2
else
  # The stride is the number of GPUs, so resizing gpu_ids neither skips jobs
  # nor schedules the same job twice.
  for ((j = 0; j < num_jobs; j += num_gpus)); do
    pids=()
    labels=()

    for i in "${!gpu_ids[@]}"; do
      idx=$((j + i))
      if (( idx >= num_jobs )); then
        break
      fi

      gpu_id=${gpu_ids[$i]}
      base_model=${base_models[$idx]}
      model_name=${model_names[$idx]}
      data_path=${data_paths[$idx]}
      data_name=${data_names[$idx]}
      lora_r=${lora_rs[$idx]}

      run_tuning_and_evaluation "$gpu_id" "$base_model" "$model_name" "$data_path" "$data_name" "$lora_r" &
      pids+=("$!")
      labels+=("${model_name}/${data_name} on gpu ${gpu_id}")
    done

    # Wait for each job of the batch individually so that a failure is
    # reported instead of being discarded by a bare `wait`.
    for k in "${!pids[@]}"; do
      wait "${pids[$k]}" || echo "job failed (exit $?): ${labels[$k]}" >&2
    done
  done
fi