#!/bin/bash
#SBATCH --job-name=finetune_shortgpt_llama3_xsum_0.8_0.9
#SBATCH --partition=gpuB
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:4
#SBATCH -o %J.out
#SBATCH -e %J.err

# ---------------------------------------------------------------------------
# Software environment: load the anaconda module, then activate the "come"
# conda environment. Activation must come after the module load.
# ---------------------------------------------------------------------------
module load anaconda3
source activate come

# Set the tokenizers parallelism switch to "false" and export it so that
# child processes started by this script inherit it.
TOKENIZERS_PARALLELISM='false'
export TOKENIZERS_PARALLELISM

# Interpreter inside the "come" environment.
# NOTE(review): PYTHON is not referenced by the torchrun launch further down.
PYTHON=~/.conda/envs/come/bin/python

# ---------------------------------------------------------------------------
# Task and launcher settings consumed by the torchrun command further down.
# ---------------------------------------------------------------------------
# Dataset name handed to the training script via --sft_dataset.
SFT_DATASET='xsum'

# Single-node rendezvous: the workers meet on this host and port.
MASTER_ADDR='localhost'
MASTER_PORT='3002'

# Number of worker processes torchrun starts (matches --gres=gpu:4 above).
GPUS_PER_NODE='4'


# ShortGPT-pruned checkpoints to fine-tune.
# Every list entry is a path relative to PRUNED_MODEL_BASE_PATH; the loop
# further down joins the two to obtain the full checkpoint location.
PRUNED_MODEL_BASE_PATH='/TO/MY/PATH/code/Understanding_Performance_Collapse/pruned_models/shortgpt/calib_arc_challenge'

# Checkpoints included in this run, processed in the order they are appended.
PRUNED_MODEL_FILE_LIST=()
PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.125/Meta-Llama-3.1-8B-Instruct_shortgpt_28')
PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.219/Meta-Llama-3.1-8B-Instruct_shortgpt_25')

# Disabled checkpoints -- uncomment a line to add it to this run.
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.312/Meta-Llama-3.1-8B-Instruct_shortgpt_22')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.25/Meta-Llama-3.1-8B-Instruct_shortgpt_24')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.031/Meta-Llama-3.1-8B-Instruct_shortgpt_31')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.062/Meta-Llama-3.1-8B-Instruct_shortgpt_30')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.094/Meta-Llama-3.1-8B-Instruct_shortgpt_29')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.156/Meta-Llama-3.1-8B-Instruct_shortgpt_27')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.188/Meta-Llama-3.1-8B-Instruct_shortgpt_26')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.281/Meta-Llama-3.1-8B-Instruct_shortgpt_23')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.344/Meta-Llama-3.1-8B-Instruct_shortgpt_21')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.375/Meta-Llama-3.1-8B-Instruct_shortgpt_20')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.406/Meta-Llama-3.1-8B-Instruct_shortgpt_19')
# PRUNED_MODEL_FILE_LIST+=('llama3-8b-0.438/Meta-Llama-3.1-8B-Instruct_shortgpt_18')

# DeepSpeed configuration file handed to the training script.
DS_CONFIG=/TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/ds_config.json

#######################################
# Fine-tune every checkpoint in PRUNED_MODEL_FILE_LIST, one torchrun launch
# per checkpoint, run sequentially.
# A failed launch does not stop the batch: the remaining checkpoints are still
# attempted and every failure is reported once the loop has finished.
# Globals (read): PRUNED_MODEL_FILE_LIST, PRUNED_MODEL_BASE_PATH, SFT_DATASET,
#                 DS_CONFIG, GPUS_PER_NODE, MASTER_ADDR, MASTER_PORT
# Arguments:      none
# Outputs:        progress messages on stdout, failure reports on stderr
# Returns:        0 if every launch succeeded, 1 if at least one failed
#######################################
run_all_finetunes() {
    local -a failed_models=()
    local model_file model_path out_dir status

    echo ">>> Starting batch finetune for ${#PRUNED_MODEL_FILE_LIST[@]} models..."
    for model_file in "${PRUNED_MODEL_FILE_LIST[@]}"; do

        echo "-------------------------------------------------------------"
        echo ">>> Processing model: ${model_file}"

        model_path="${PRUNED_MODEL_BASE_PATH}/${model_file}"

        # Output directory. The "_lr1e-5_bs1_ep3" suffix mirrors the
        # --learning_rate, --per_device_train_batch_size and --num_train_epochs
        # values below; keep them in sync when changing hyperparameters.
        out_dir="/TO/MY/PATH/code/Understanding_Performance_Collapse/shortgpt_outputs/calib_arcc/finetune/Llama3.1-8B-Instruct/${SFT_DATASET}/${model_file}_lr1e-5_bs1_ep3"

        # Capture the exit status instead of discarding it, so a crashed run
        # is not silently reported as completed.
        status=0
        torchrun --nproc_per_node="${GPUS_PER_NODE}" --master_addr "${MASTER_ADDR}" --master_port "${MASTER_PORT}" /TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/finetune_OtherPrunedMethod.py \
          --model_name_or_path "${model_path}" \
          --sft_dataset "${SFT_DATASET}" \
          --output_dir "${out_dir}" \
          --deepspeed_config "${DS_CONFIG}" \
          --max_length 512 \
          --per_device_train_batch_size 1 \
          --gradient_accumulation_steps 1 \
          --num_train_epochs 3 \
          --learning_rate 1e-5 \
          --warmup_ratio 0.03 \
          --logging_steps 50 \
          --save_iter 100000 \
          --seed 0 || status=$?

        if (( status != 0 )); then
            echo ">>> ERROR: finetune failed for ${model_file} (exit code ${status})" >&2
            failed_models+=("${model_file}")
        fi

        echo "-------------------------------------------------------------"
    done

    if (( ${#failed_models[@]} > 0 )); then
        echo ">>> ${#failed_models[@]} of ${#PRUNED_MODEL_FILE_LIST[@]} finetuning tasks failed:" >&2
        printf '>>>   %s\n' "${failed_models[@]}" >&2
        return 1
    fi

    echo ">>> All finetuning tasks completed!"
}

# Exit non-zero when any run failed so the batch job itself is marked as failed.
run_all_finetunes || exit 1
