#!/bin/bash
#SBATCH --job-name=finetune_tela_depthonly
#SBATCH --partition=lvjq
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=6
#SBATCH --gres=gpu:4
#SBATCH -o %J.out
#SBATCH -e %J.err

# Environment setup: load the anaconda3 environment module, then activate
# the conda environment named "come". Order matters: the module must be
# loaded before the environment can be activated.
module load anaconda3
source activate come

# Path to the Python interpreter inside the "come" conda environment.
# NOTE(review): PYTHON is not referenced anywhere else in this script (the
# training command is launched through `torchrun`) — confirm whether it is
# still needed.
PYTHON=~/.conda/envs/come/bin/python
# Exported so that child processes inherit it.
# Presumably turns off parallelism in the Hugging Face tokenizers library —
# TODO confirm.
export TOKENIZERS_PARALLELISM=false

# Values handed to torchrun as --master_addr, --master_port and
# --nproc_per_node. GPUS_PER_NODE matches the `--gres=gpu:4` request in the
# SBATCH header.
MASTER_ADDR=localhost
MASTER_PORT=2016
GPUS_PER_NODE=4

# Dataset name passed to --sft_dataset; also used as a path component of the
# per-model output directory.
SFT_DATASET="mmlu"

# Base path of the pruned models (HF format).
PRUNED_MODEL_BASE_PATH="/TO/MY/PATH/code/Understanding_Performance_Collapse/tale_outputs/prun/Llama3.1-8B-Instruct/"

# Models to loop over. Every entry lives under the same pruning-run
# directory, so the list is assembled from that shared prefix plus the
# checkpoint directory name (note that step 0005 is not in the list).
_run_dir="tau1_2-tau2_2-PWidthCoarse_0.12-PWidthFine_0.06-TargetSparsity_0.5-Iters_1000-NewPrun1-DepthOnly"
MODEL_LIST=()
for _ckpt in \
    "tale_step_0001_s0.030" \
    "tale_step_0002_s0.061" \
    "tale_step_0003_s0.091" \
    "tale_step_0004_s0.121" \
    "tale_step_0006_s0.212"
do
    MODEL_LIST+=("${_run_dir}/${_ckpt}")
done
# Drop the helper variables so only the configuration values remain set.
unset _run_dir _ckpt

# DeepSpeed configuration file.
DS_CONFIG="/TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/ds_config.json"

# Fine-tune every pruned checkpoint listed in MODEL_LIST, one after another.
#
# Reads:  MODEL_LIST, PRUNED_MODEL_BASE_PATH, SFT_DATASET, DS_CONFIG,
#         GPUS_PER_NODE, MASTER_ADDR, MASTER_PORT (all assigned earlier in
#         this script).
# Output: progress lines on stdout; failure reports on stderr.
# Exit:   a failed run does not stop the remaining models, but the script
#         exits with status 1 if at least one run failed, so the batch job
#         is not reported as successful.
echo ">>> Starting batch finetune for ${#MODEL_LIST[@]} models..."

# Checkpoints whose torchrun invocation returned a non-zero status.
FAILED_MODELS=()

for PRUNED_MODEL_FILE in "${MODEL_LIST[@]}"; do
    echo "-------------------------------------------------------------"
    echo ">>> Processing model: ${PRUNED_MODEL_FILE}"

    # Strip any trailing slash from the base path before joining, so the
    # resulting model path does not contain a doubled "//".
    PRUNED_MODEL_PATH="${PRUNED_MODEL_BASE_PATH%/}/${PRUNED_MODEL_FILE}/"

    OUT_DIR="/TO/MY/PATH/code/Understanding_Performance_Collapse/tale_outputs/finetune/Llama3.1-8B-Instruct/${SFT_DATASET}/${PRUNED_MODEL_FILE}_AllLabel"

    # Build the command as an argv array instead of a string passed to
    # `eval`: each element reaches torchrun as exactly one argument, with no
    # second round of shell parsing and no hand-escaped quotes.
    CMD=(
        torchrun
        --nproc_per_node="${GPUS_PER_NODE}"
        --master_addr "${MASTER_ADDR}"
        --master_port "${MASTER_PORT}"
        /TO/MY/PATH/code/Understanding_Performance_Collapse/TALE/finetune_pruned_llm.py
        --model_name_or_path "${PRUNED_MODEL_PATH}"
        --sft_dataset "${SFT_DATASET}"
        --output_dir "${OUT_DIR}"
        --deepspeed_config "${DS_CONFIG}"
        --max_length 512
        --per_device_train_batch_size 2
        --gradient_accumulation_steps 1
        --num_train_epochs 3
        --learning_rate 5e-5
        --warmup_ratio 0.03
        --logging_steps 50
        --save_iter 1000
        --seed 42
    )

    # Actually run it, capturing the exit status so a failure is recorded
    # rather than silently ignored.
    RUN_STATUS=0
    "${CMD[@]}" || RUN_STATUS=$?
    if (( RUN_STATUS != 0 )); then
        echo ">>> ERROR: finetune failed (exit ${RUN_STATUS}) for model: ${PRUNED_MODEL_FILE}" >&2
        FAILED_MODELS+=("${PRUNED_MODEL_FILE}")
    fi
    echo "-------------------------------------------------------------"
done

# Only claim success when every run returned 0.
if (( ${#FAILED_MODELS[@]} > 0 )); then
    echo ">>> ${#FAILED_MODELS[@]} of ${#MODEL_LIST[@]} finetuning tasks failed:" >&2
    printf '      %s\n' "${FAILED_MODELS[@]}" >&2
    exit 1
fi

echo ">>> All finetuning tasks completed!"
