
# Fine-tune a RoBERTa classifier (RoBERTa classifier SFT) on a labeled BIRD
# training file, then run an inference example.
#
# Steps:
#   1. Build the classifier SFT dataset from the labeled JSONL file.
#   2. Launch distributed training with torchrun, one process per GPU listed
#      in SFT_GPUS.
#   3. Run the inference example module.
#
# Every setting below can be overridden from the environment, e.g.
#   SFT_GPUS=0,1 TRAINING_MODE=lora bash <this script>
# NOTE(review): the accepted values of SFT_CONFIG / TRAINING_MODE are defined
# by src.sft.roberta_classifier_sft — confirm there before changing them.
# NOTE(review): the relative LABELED_FILE path and the `python -m src...`
# module names assume the current directory is the project root (or that
# PYTHONPATH is set accordingly) — TODO confirm.

# Abort on the first failing step: training must not run on a missing/stale
# dataset, and inference must not run after a failed training job.
set -euo pipefail

SFT_DATASET="${SFT_DATASET:-bird_train_full_roberta}"
LABELED_FILE="${LABELED_FILE:-data/labeled/bird_train_pipeline_label.jsonl}"
SFT_CONFIG="${SFT_CONFIG:-roberta_config}"
TRAINING_MODE="${TRAINING_MODE:-lora}"
MASTER_PORT="${MASTER_PORT:-29500}"

# Comma-separated GPU ids used for training. A script-specific variable is
# used (instead of reading CUDA_VISIBLE_DEVICES) so that a value already
# exported in the caller's shell does not silently change which GPUs run.
SFT_GPUS="${SFT_GPUS:-4,5,6,7}"

# Fail early with a clear message rather than deep inside the Python step.
if [[ ! -f "${LABELED_FILE}" ]]; then
    printf 'error: labeled file not found: %s\n' "${LABELED_FILE}" >&2
    exit 1
fi

# Step 1: build the SFT dataset named SFT_DATASET from the labeled file.
python -m src.sft.prepare_classifier_sft_data \
    --sft_dataset "${SFT_DATASET}" \
    --labeled_file "${LABELED_FILE}"

# Disable NCCL's peer-to-peer and InfiniBand transports so inter-GPU traffic
# falls back to other transports (shared memory / sockets).
# NOTE(review): presumably a workaround for hangs or unsupported P2P/IB on this
# machine — confirm; it can slow multi-GPU communication.
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1

export CUDA_VISIBLE_DEVICES="${SFT_GPUS}"

# One torchrun worker per visible GPU. Deriving the count from the device list
# keeps the two from drifting apart; NPROC_PER_NODE may still override it.
IFS=',' read -r -a gpu_ids <<< "${CUDA_VISIBLE_DEVICES}"
NPROC_PER_NODE="${NPROC_PER_NODE:-${#gpu_ids[@]}}"

# Step 2: distributed training.
torchrun --nproc_per_node="${NPROC_PER_NODE}" --master_port="${MASTER_PORT}" \
    -m src.sft.roberta_classifier_sft \
    --sft_config "${SFT_CONFIG}" \
    --sft_dataset "${SFT_DATASET}" \
    --training_mode "${TRAINING_MODE}"

# Step 3: inference example (same banner bytes as `echo -e "\n\n...\n"`).
printf '\n\n----------------- Inference example of RoBERTa Classifier SFT -----------------\n\n'
python -m src.sft.roberta_classifier_inference
