#!/bin/bash
#SBATCH --job-name=finetune_ours_llama3_xsum
#SBATCH --partition=lvjq
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:4
#SBATCH -o %J.out
#SBATCH -e %J.err

module load anaconda3
# Abort early: without the "come" env every torchrun call below would fail.
source activate come || { echo "!!! Failed to activate conda env 'come'" >&2; exit 1; }

export TOKENIZERS_PARALLELISM=false
# Interpreter of the "come" env. Not referenced below (torchrun is resolved
# from PATH after activation); kept for reference.
# shellcheck disable=SC2034
PYTHON=~/.conda/envs/come/bin/python

# torchrun rendezvous settings (single node, see #SBATCH --nodes=1).
MASTER_ADDR=localhost
MASTER_PORT=30011
GPUS_PER_NODE=4

SFT_DATASET=xsum

# Training hyperparameters. LR / BATCH_SIZE / EPOCHS are also encoded in each
# output directory name, so the name always matches what was actually run.
MAX_LENGTH=512
BATCH_SIZE=2
GRAD_ACCUM_STEPS=1
EPOCHS=3
LR=5e-6
WARMUP_RATIO=0.03
LOGGING_STEPS=50
SAVE_ITER=10000

PROJECT_ROOT=/TO/MY/PATH/code/Understanding_Performance_Collapse
RUN_ROOT=${PROJECT_ROOT}/iter_shortgpt_output/calib_arc_challenge/llama3-8b
TRAIN_SCRIPT=${PROJECT_ROOT}/TALE/finetune_OtherPrunedMethod.py

# DeepSpeed config
DS_CONFIG=${PROJECT_ROOT}/TALE/ds_config.json

# Names of models whose checkpoint was missing or whose training run failed.
FAILED_MODELS=()

#######################################
# Fine-tune each pruned checkpoint in a list, one torchrun job at a time.
# Globals:
#   GPUS_PER_NODE, MASTER_ADDR, MASTER_PORT, TRAIN_SCRIPT, SFT_DATASET,
#   DS_CONFIG and the hyperparameter variables (read);
#   FAILED_MODELS (appended to)
# Arguments:
#   $1    - directory containing the pruned checkpoints (trailing / allowed)
#   $2    - base output directory; each run writes to
#           $2/$SFT_DATASET/<model>_lr<LR>_bs<BATCH_SIZE>_ep<EPOCHS>
#   $3    - random seed passed to the training script
#   $4... - checkpoint names inside $1
# Returns:
#   0 always. A failure is recorded in FAILED_MODELS and the remaining
#   models are still processed, so one bad run does not waste the job.
#######################################
run_finetune_batch() {
    local model_base=${1%/}   # strip trailing slash to avoid "//" in paths
    local out_base=$2
    local seed=$3
    shift 3
    local model_file model_path out_dir

    echo ">>> Starting batch finetune for $# models..."
    for model_file in "$@"; do
        echo "-------------------------------------------------------------"
        echo ">>> Processing model: ${model_file}"

        model_path=${model_base}/${model_file}

        # Output directory
        out_dir=${out_base}/${SFT_DATASET}/${model_file}_lr${LR}_bs${BATCH_SIZE}_ep${EPOCHS}

        if [[ ! -e "${model_path}" ]]; then
            echo "!!! Pruned model not found, skipping: ${model_path}" >&2
            FAILED_MODELS+=("${model_file}")
            echo "-------------------------------------------------------------"
            continue
        fi

        if ! torchrun --nproc_per_node="${GPUS_PER_NODE}" \
            --master_addr "${MASTER_ADDR}" --master_port "${MASTER_PORT}" \
            "${TRAIN_SCRIPT}" \
            --model_name_or_path "${model_path}" \
            --sft_dataset "${SFT_DATASET}" \
            --output_dir "${out_dir}" \
            --deepspeed_config "${DS_CONFIG}" \
            --max_length "${MAX_LENGTH}" \
            --per_device_train_batch_size "${BATCH_SIZE}" \
            --gradient_accumulation_steps "${GRAD_ACCUM_STEPS}" \
            --num_train_epochs "${EPOCHS}" \
            --learning_rate "${LR}" \
            --warmup_ratio "${WARMUP_RATIO}" \
            --logging_steps "${LOGGING_STEPS}" \
            --save_iter "${SAVE_ITER}" \
            --seed "${seed}"; then
            echo "!!! Finetuning failed for model: ${model_file}" >&2
            FAILED_MODELS+=("${model_file}")
        fi
        echo "-------------------------------------------------------------"
    done
}

# Group 1: checkpoints under ContinuePrun-from-ShortGPT-31Layer.
MODELS_FROM_31_LAYER=(
    "Meta-Llama-3.1-8B-Instruct_shortgpt_28"
    "Meta-Llama-3.1-8B-Instruct_shortgpt_25"
)
run_finetune_batch \
    "${RUN_ROOT}/prun/ContinuePrun-from-ShortGPT-31Layer" \
    "${RUN_ROOT}/finetune/from-ShortGPT-31Layer" \
    0 \
    "${MODELS_FROM_31_LAYER[@]}"

# Group 2: 30% pruning ratio, checkpoints under ContinuePrun-from-ShortGPT-24Layer.
# NOTE(review): this group uses seed 42 while group 1 uses seed 0, exactly as
# in the original script -- confirm the difference is intentional.
MODELS_FROM_24_LAYER=(
    "Meta-Llama-3.1-8B-Instruct_shortgpt_24_shortgpt_22"
)
run_finetune_batch \
    "${RUN_ROOT}/prun/ContinuePrun-from-ShortGPT-24Layer" \
    "${RUN_ROOT}/finetune/from-ShortGPT-24Layer" \
    42 \
    "${MODELS_FROM_24_LAYER[@]}"

# Report failures and exit non-zero so SLURM marks the job as failed.
if (( ${#FAILED_MODELS[@]} > 0 )); then
    echo ">>> Finetuning finished with ${#FAILED_MODELS[@]} failure(s): ${FAILED_MODELS[*]}" >&2
    exit 1
fi
echo ">>> All finetuning tasks completed!"
