#!/bin/bash
#SBATCH --job-name=finetune_mka_llama3_xsum
#SBATCH --partition=gpuB
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:4
#SBATCH -o %J.out
#SBATCH -e %J.err

# Activate the software environment. Abort right away if this fails: otherwise
# every torchrun launch below would fail one after another on the allocated GPUs
# with a much less obvious error.
module load anaconda3 || { echo ">>> ERROR: 'module load anaconda3' failed" >&2; exit 1; }
source activate come || { echo ">>> ERROR: could not activate conda env 'come'" >&2; exit 1; }

# Disable Hugging Face tokenizers' internal parallelism (avoids fork-related
# warnings/deadlocks when data-loading workers are forked).
export TOKENIZERS_PARALLELISM=false
# Interpreter of the 'come' env. NOTE(review): not referenced below; torchrun
# is resolved from PATH after activation.
PYTHON=~/.conda/envs/come/bin/python

export CUDA_HOME=/usr/local/cuda
export PATH="${CUDA_HOME}/bin:${PATH}"
# Only append the inherited LD_LIBRARY_PATH when it is non-empty. A trailing ':'
# would add an empty search entry, which the dynamic loader treats as the
# current working directory.
export LD_LIBRARY_PATH="${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# Rendezvous settings for the single-node torchrun launches below.
MASTER_ADDR=localhost
# Can be overridden from the environment, so two jobs placed on the same node
# can use distinct ports. Defaults to the original fixed value.
MASTER_PORT=${MASTER_PORT:-3004}
# Worker processes per node. Follow the GPU count Slurm actually allocated, so
# this cannot drift from "#SBATCH --gres=gpu:4". Fall back to 4 when
# SLURM_GPUS_ON_NODE is not set.
GPUS_PER_NODE=${SLURM_GPUS_ON_NODE:-4}

# SFT dataset name, forwarded to the training script as --sft_dataset.
SFT_DATASET=xsum

# Pruned checkpoints to finetune, relative to PRUNED_MODEL_BASE_PATH.
PRUNED_MODEL_FILE_LIST=(
    "llama3-8b-0.9/Meta-Llama-3.1-8B-Instruct_mka_28"
    "llama3-8b-0.8/Meta-Llama-3.1-8B-Instruct_mka_25"
    "llama3-8b-0.7/Meta-Llama-3.1-8B-Instruct_mka_22"
)

# Root directory that holds the pruned model checkpoints listed above.
PRUNED_MODEL_BASE_PATH=/TO/MY/PATH/code/Understanding_Performance_Collapse/pruned_models/mka
# DeepSpeed configuration file, forwarded as --deepspeed_config.
DS_CONFIG=/TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/ds_config.json

echo ">>> Starting batch finetune for ${#PRUNED_MODEL_FILE_LIST[@]} models..."

# Fail fast on a missing DeepSpeed config, rather than letting each model's
# torchrun launch trip over it in turn.
if [[ ! -f "${DS_CONFIG}" ]]; then
    echo ">>> ERROR: DeepSpeed config not found: ${DS_CONFIG}" >&2
    exit 1
fi

# Models whose run was skipped or exited non-zero; reported after the loop.
FAILED_MODELS=()

for PRUNED_MODEL_FILE in "${PRUNED_MODEL_FILE_LIST[@]}"; do

    echo "-------------------------------------------------------------"
    echo ">>> Processing model: ${PRUNED_MODEL_FILE}"

    PRUNED_MODEL_PATH="${PRUNED_MODEL_BASE_PATH}/${PRUNED_MODEL_FILE}"

    # Output directory: one per (SFT dataset, pruned model) pair.
    OUT_DIR="/TO/MY/PATH/code/Understanding_Performance_Collapse/mka_outputs/finetune/Llama3.1-8B-Instruct/${SFT_DATASET}/${PRUNED_MODEL_FILE}"

    # Skip and record a model whose checkpoint path is missing, instead of
    # launching GPU workers that can only fail.
    if [[ ! -e "${PRUNED_MODEL_PATH}" ]]; then
        echo ">>> ERROR: pruned model not found, skipping: ${PRUNED_MODEL_PATH}" >&2
        FAILED_MODELS+=("${PRUNED_MODEL_FILE}")
        echo "-------------------------------------------------------------"
        continue
    fi

    torchrun --nproc_per_node="${GPUS_PER_NODE}" --master_addr "${MASTER_ADDR}" --master_port "${MASTER_PORT}" /TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/finetune_OtherPrunedMethod.py \
      --model_name_or_path "${PRUNED_MODEL_PATH}" \
      --sft_dataset "${SFT_DATASET}" \
      --output_dir "${OUT_DIR}" \
      --deepspeed_config "${DS_CONFIG}" \
      --max_length 512 \
      --per_device_train_batch_size 1 \
      --gradient_accumulation_steps 1 \
      --num_train_epochs 3 \
      --learning_rate 5e-5 \
      --warmup_ratio 0.03 \
      --logging_steps 50 \
      --save_iter 10000 \
      --seed 42
    rc=$?

    # A failed run does not abort the batch: the remaining models still train,
    # and the failure is reported (and reflected in the exit code) at the end.
    if (( rc != 0 )); then
        echo ">>> ERROR: finetuning failed for ${PRUNED_MODEL_FILE} (exit code ${rc})" >&2
        FAILED_MODELS+=("${PRUNED_MODEL_FILE}")
    fi

    echo "-------------------------------------------------------------"
done

# Exit non-zero so Slurm marks the job as FAILED when any model did not finish.
if (( ${#FAILED_MODELS[@]} > 0 )); then
    echo ">>> ${#FAILED_MODELS[@]} of ${#PRUNED_MODEL_FILE_LIST[@]} finetuning tasks failed: ${FAILED_MODELS[*]}" >&2
    exit 1
fi

echo ">>> All finetuning tasks completed!"
