source scripts/common_setting.sh

# ---------------------------------------------------------------------------
# Job configuration.
#
# The arrays below are index-aligned: element i of every array describes
# job i, and the scheduling loop at the bottom of this file reads them by
# the same index. Keep all of them the same length.
# ---------------------------------------------------------------------------

# Prefix of the checkpoint / result directory names and the value passed to
# eval_metric.py via --name.
tune_ckpt_path="sleb"

# Base model passed to eval_metric.py via --base_model for each job.
# NOTE(review): $qwen7b / $llama2 are presumably defined by
# scripts/common_setting.sh -- confirm.
base_models=("$qwen7b" "$qwen7b" "$qwen7b" "$qwen7b")
# base_models=("$llama2" "$llama2" "$llama2" "$llama2")

# Short model tag used in the block-order CSV name and in output directories.
model_names=("qwen7b" "qwen7b" "qwen7b" "qwen7b")
# model_names=("llama2" "llama2" "llama2" "llama2")

# Dataset identifier per job; the number of entries here is the job count.
data_paths=(
  "yahma/alpaca-cleaned" "yahma/alpaca-cleaned"
  "yahma/alpaca-cleaned" "yahma/alpaca-cleaned"
)

# Short dataset tag used as the last component of the output directories.
data_names=("alpaca" "alpaca" "alpaca" "alpaca")

# Value passed as --num_remove_blocks for each job.
num_remove_blocks=(4 8 12 16)

# LoRA r value; it appears in the checkpoint and results directory names.
lora_r=8

# Device ids exported as CUDA_VISIBLE_DEVICES, one per concurrently running job.
gpu_ids=(4 5 6 7)

#######################################
# Run the perplexity evaluation (eval_metric.py) for one job on one GPU.
# The pruning, tuning and accuracy-evaluation stages are currently disabled
# and kept below as commented-out reference commands.
# Globals:
#   tune_ckpt_path (read) - prefix of the checkpoint/result directory names
#   lora_r         (read) - LoRA r value embedded in the directory names
# Arguments:
#   $1 - GPU id exported as CUDA_VISIBLE_DEVICES
#   $2 - base model passed to --base_model
#   $3 - short model name used in file and directory names
#   $4 - dataset path (only used by the disabled tuning stage)
#   $5 - short dataset name used in directory names
#   $6 - value passed to --num_remove_blocks
# Outputs:
#   Progress messages with timestamps on stdout.
# Returns:
#   Exit status of the eval_metric.py invocation.
#######################################
run_tuning_and_evaluation() {
  local gpu_id=$1
  local base_model=$2
  local model_name=$3
  local data_path=$4
  local data_name=$5
  local remove_blocks=$6

  local block_order_path="baselines/SLEB/sleb_results/block_order_${model_name}.csv"
  # Directory tag shared by the adapter checkpoint and the results folder.
  local run_tag="${tune_ckpt_path}_rm${remove_blocks}_lora${lora_r}"
  local current_time
  local status=0

  # ---------------- prune model (disabled) ------------------
#  CUDA_VISIBLE_DEVICES=$gpu_id python baselines/SLEB/sleb.py \
#    --model_name ${base_model} \
#    --num_blocks 32 \
#    --num_remove_blocks ${remove_blocks} \
#    --eval_ppl False \
#    --result_file baselines/SLEB/sleb_results_${remove_blocks}.txt \
#    --eval_zeroshot False

  # ---------------- tuning (disabled) ------------------
#  echo "base_model: ${base_model}"
#  echo "Lora Config: lora_r=($lora_r),"
#  current_time=$(date "+%Y-%m-%d %H:%M:%S")
#  echo "Start tuning on gpu: $gpu_id, $current_time"
#
#  CUDA_VISIBLE_DEVICES=$gpu_id python baselines/static_finetune.py \
#     --base_model $base_model \
#     --data_name $data_name \
#     --data_path $data_path \
#     --block_order_path ${block_order_path} \
#     --num_remove_blocks ${remove_blocks} \
#     --output_dir tune_log/$model_name/${tune_ckpt_path}_rm${remove_blocks}_lora${lora_r}/$data_name/ \
#     --lora_r $lora_r \
#     --cutoff_len 512 \
#     --num_epochs 2 \
#     --learning_rate 1e-4 \
#     --batch_size 8

  # ---------------- eval acc (disabled) ------------------
#  echo "base_model: ${base_model}"
#  current_time=$(date "+%Y-%m-%d %H:%M:%S")
#  echo "Start evaluation on gpu: $gpu_id, $current_time"
#
#  CUDA_VISIBLE_DEVICES=$gpu_id python baselines/static_eval.py \
#      --model hf \
#      --pretrained $base_model \
#      --peft=tune_log/$model_name/${tune_ckpt_path}_rm${remove_blocks}_lora${lora_r}/$data_name/ \
#      --block_order_path ${block_order_path} \
#      --num_remove_blocks ${remove_blocks} \
#      --tasks openbookqa,piqa,boolq,social_iqa,hellaswag,arc_easy,winogrande,arc_challenge \
#      --device cuda:0 \
#      --batch_size 4 \
#      --output_path results/$model_name/${tune_ckpt_path}_rm${remove_blocks}_lora${lora_r}/$data_name/
#
#  current_time=$(date "+%Y-%m-%d %H:%M:%S")
#  echo "End evaluation on gpu: $gpu_id, $current_time"

  # ---------------- eval ppl ------------------
  echo "base_model: ${base_model}"
  current_time=$(date "+%Y-%m-%d %H:%M:%S")
  echo "Start evaluation on gpu: $gpu_id, $current_time"

  # Every expansion is quoted so that paths containing whitespace or glob
  # characters reach eval_metric.py as a single argument each. The exit
  # status is captured rather than dropped so callers can detect failures.
  CUDA_VISIBLE_DEVICES="$gpu_id" python eval_metric.py \
      --name "${tune_ckpt_path}" \
      --base_model "$base_model" \
      --peft="tune_log/${model_name}/${run_tag}/${data_name}/" \
      --block_order_path "${block_order_path}" \
      --num_remove_blocks "${remove_blocks}" \
      --tasks wikitext2,ptb \
      --batch_size 4 \
      --cutoff_len 256 \
      --output_path "results/${model_name}/${run_tag}/${data_name}/" \
    || status=$?

  current_time=$(date "+%Y-%m-%d %H:%M:%S")
  echo "End evaluation on gpu: $gpu_id, $current_time"

  return "$status"
}

#######################################
# Launch every configured job, one per GPU, in batches.
# Jobs are taken in index order from the parallel configuration arrays; each
# batch starts at most one background job per entry in gpu_ids and the next
# batch begins only after all jobs of the current batch have exited.
# The batch stride is the number of GPUs (not a fixed 4), so every job runs
# exactly once regardless of how many GPUs are listed.
# Globals:
#   gpu_ids, base_models, model_names, data_paths, data_names,
#   num_remove_blocks (all read); data_paths determines the job count.
# Returns:
#   1 if jobs are configured but gpu_ids is empty, 0 otherwise.
#######################################
launch_all_jobs() {
  local num_jobs=${#data_paths[@]}
  local num_gpus=${#gpu_ids[@]}
  local batch_start slot idx

  # A zero stride would never advance the batch loop.
  if ((num_jobs > 0 && num_gpus == 0)); then
    echo "gpu_ids is empty; cannot schedule ${num_jobs} job(s)" >&2
    return 1
  fi

  for ((batch_start = 0; batch_start < num_jobs; batch_start += num_gpus)); do
    for ((slot = 0; slot < num_gpus; slot++)); do
      idx=$((batch_start + slot))
      # The last batch may be only partially filled.
      if ((idx >= num_jobs)); then
        break
      fi

      run_tuning_and_evaluation \
        "${gpu_ids[slot]}" \
        "${base_models[idx]}" \
        "${model_names[idx]}" \
        "${data_paths[idx]}" \
        "${data_names[idx]}" \
        "${num_remove_blocks[idx]}" &
    done
    wait  # Wait for all tuning and evaluation processes to finish
  done
}

launch_all_jobs