#!/bin/bash
#SBATCH --job-name=finetune_flap_llama3_xsum
#SBATCH --partition=lvjq
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:4
#SBATCH -o %J.out
#SBATCH -e %J.err

# ---------------------------------------------------------------------------
# Environment setup
# ---------------------------------------------------------------------------
# Abort immediately if the toolchain cannot be prepared: without these checks
# a failed module load / env activation would silently fall through and the
# job would run with whatever python/torchrun happens to be on PATH.
module load anaconda3 || {
    echo "ERROR: 'module load anaconda3' failed" >&2
    exit 1
}
source activate come || {
    echo "ERROR: could not activate conda environment 'come'" >&2
    exit 1
}
command -v torchrun >/dev/null || {
    echo "ERROR: torchrun not found on PATH after activating 'come'" >&2
    exit 1
}

# NOTE(review): PYTHON is not referenced by any command visible in this script
# (torchrun is resolved through PATH); kept in case another part relies on it.
PYTHON=~/.conda/envs/come/bin/python
# NOTE(review): presumably set to avoid tokenizer worker-thread warnings in
# forked data-loader processes; confirm against the training script.
export TOKENIZERS_PARALLELISM=false

export CUDA_HOME=/usr/local/cuda
export PATH="${CUDA_HOME}/bin:${PATH}"
# Append the previous value only when it is non-empty. A bare trailing ':'
# is an empty search-path entry, which the dynamic loader interprets as the
# current working directory.
export LD_LIBRARY_PATH="${CUDA_HOME}/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"

# ---------------------------------------------------------------------------
# Distributed launch settings (single node)
# ---------------------------------------------------------------------------
MASTER_ADDR=localhost
# Default port is 30055; export MASTER_PORT before submitting to pick a
# different one when several jobs may share the same node.
MASTER_PORT=${MASTER_PORT:-30055}
# Keep in sync with the '#SBATCH --gres=gpu:4' request in the job header.
GPUS_PER_NODE=4

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
# Directory that holds one sub-directory per pruned checkpoint.
PRUNED_MODEL_BASE_PATH=/TO/MY/PATH/code/Understanding_Performance_Collapse/flap_baseline_outputs/prun

# DeepSpeed configuration
DS_CONFIG=/TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/ds_config.json

# ---------------------------------------------------------------------------
# Training hyper-parameters (forwarded to the training script as CLI flags)
# ---------------------------------------------------------------------------
SFT_DATASET=xsum

MAX_EPOCHS=3
BATCH_SIZE=2
GRAD_ACC=1
LR=5e-5
MAX_LEN=512
LOG_STEPS=50

# Models to iterate over (sub-directory names under PRUNED_MODEL_BASE_PATH)
MODEL_LIST=(
    "flap_width_keep_0.7"
    "flap_width_keep_0.8"
    "flap_width_keep_0.9"
)

# Fine-tune every checkpoint listed in MODEL_LIST, one torchrun launch per
# model. A failing model does not stop the remaining ones, but it is recorded
# and the script exits non-zero at the end so the scheduler does not report
# the job as successful when a run crashed.
echo ">>> Starting batch finetune for ${#MODEL_LIST[@]} models..."

# Names of the models whose fine-tuning did not complete successfully.
FAILED_MODELS=()

for PRUNED_MODEL_FILE in "${MODEL_LIST[@]}"; do
    echo "-------------------------------------------------------------"
    echo "Processing model: ${PRUNED_MODEL_FILE}"

    PRUNED_MODEL_PATH="${PRUNED_MODEL_BASE_PATH}/${PRUNED_MODEL_FILE}/"

    OUT_DIR="/TO/MY/PATH/code/Understanding_Performance_Collapse/flap_baseline_outputs/finetune/Llama3.1-8B-Instruct/${SFT_DATASET}/${PRUNED_MODEL_FILE}_AllLabel"
    if ! mkdir -p "${OUT_DIR}"; then
        echo "ERROR: cannot create output directory: ${OUT_DIR}" >&2
        FAILED_MODELS+=("${PRUNED_MODEL_FILE}")
        echo "-------------------------------------------------------------"
        continue
    fi

    # Running torchrun as the 'if' condition lets us branch on its exit
    # status instead of unconditionally reporting success.
    if torchrun \
        --nproc_per_node="${GPUS_PER_NODE}" \
        --master_addr "${MASTER_ADDR}" \
        --master_port "${MASTER_PORT}" \
        /TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/flap_finetune.py \
        --model_name_or_path "${PRUNED_MODEL_PATH}" \
        --output_dir "${OUT_DIR}" \
        --deepspeed_config "${DS_CONFIG}" \
        --sft_dataset "${SFT_DATASET}" \
        --max_length "${MAX_LEN}" \
        --per_device_train_batch_size "${BATCH_SIZE}" \
        --gradient_accumulation_steps "${GRAD_ACC}" \
        --num_train_epochs "${MAX_EPOCHS}" \
        --learning_rate "${LR}" \
        --logging_steps "${LOG_STEPS}" \
        --save_iter -1 \
        --seed 42
    then
        echo "Finished model: ${PRUNED_MODEL_FILE}"
    else
        # In the else branch $? still holds the exit status of torchrun.
        TORCHRUN_STATUS=$?
        echo "ERROR: finetuning failed for model: ${PRUNED_MODEL_FILE} (exit code ${TORCHRUN_STATUS})" >&2
        FAILED_MODELS+=("${PRUNED_MODEL_FILE}")
    fi
    echo "-------------------------------------------------------------"
done

if (( ${#FAILED_MODELS[@]} > 0 )); then
    echo "ERROR: ${#FAILED_MODELS[@]} of ${#MODEL_LIST[@]} finetuning tasks failed: ${FAILED_MODELS[*]}" >&2
    exit 1
fi

echo "All finetuning tasks completed!"
