#!/bin/bash
#
# Launch train_VAR.py with --algorithm "reweighted_sft" under torchrun, using
# the c10d rendezvous backend.
#
# Required environment (this script only reads these; it never sets them):
#   SLURM_NNODES   value passed to torchrun --nnodes
#   SLURM_JOB_ID   value passed to torchrun --rdzv_id
#   MASTER_ADDR    host part of torchrun --rdzv_endpoint
#   MASTER_PORT    port part of torchrun --rdzv_endpoint
# Optional environment:
#   SLURM_GPUS_PER_NODE   value passed to torchrun --nproc_per_node
#                         (defaults to 1 when unset or empty)

set -euo pipefail

# Fail fast with a readable message. Without these checks an unset variable
# expands to nothing and torchrun receives a flag with no value (or an
# endpoint of ":"), which produces a much less obvious error.
: "${SLURM_NNODES:?SLURM_NNODES must be set (used for torchrun --nnodes)}"
: "${SLURM_JOB_ID:?SLURM_JOB_ID must be set (used for torchrun --rdzv_id)}"
: "${MASTER_ADDR:?MASTER_ADDR must be set (used for torchrun --rdzv_endpoint)}"
: "${MASTER_PORT:?MASTER_PORT must be set (used for torchrun --rdzv_endpoint)}"

# Resolve the per-node process count once so that the run name and the
# torchrun invocation always agree on it.
NGPUS_PER_NODE="${SLURM_GPUS_PER_NODE:-1}"

# ---- Experiment configuration ----------------------------------------------
MODEL_NAME_OR_PATH="meta-llama/Llama-3.1-8B"
TRAIN_FILE_PATH="example_train_data.jsonl"
EVAL_FILE_PATH="example_val_data.jsonl"
EPOCHS=2
BATCH_SIZE=1
LR=1e-5
PROJECT_NAME="RUN--r_sft--llama3.1_8b"

# The run name records the main hyperparameters plus a launch timestamp.
# NOTE(review): the "pre_bs8" tag is hard-coded; it presumably mirrors
# "--pre_defined_B 8" below -- keep the two in sync if that value changes.
# NOTE(review): the timestamp is taken when this script runs; if the script is
# executed separately on each node, names may differ between nodes -- confirm.
RUN_NAME="offline_data--pre_bs8--bs_${BATCH_SIZE}--epoch_${EPOCHS}--lr_${LR}--nnodes_${SLURM_NNODES}--ngpus_${NGPUS_PER_NODE}--$(date +%Y%m%d_%H%M%S)"
OUTPUT_DIR="./checkpoints/${PROJECT_NAME}/${RUN_NAME}"

# ---- torchrun (launcher) arguments -----------------------------------------
launcher_args=(
    --nnodes "${SLURM_NNODES}"
    --nproc_per_node "${NGPUS_PER_NODE}"
    --rdzv_id "${SLURM_JOB_ID}"
    --rdzv_backend c10d
    --rdzv_endpoint "${MASTER_ADDR}:${MASTER_PORT}"
)

# ---- train_VAR.py arguments --------------------------------------------------
# Arrays keep every value a single word regardless of its content (the system
# prompt contains spaces) and allow comments between argument groups.
train_args=(
    --algorithm "reweighted_sft"

    # Model, data and output locations.
    --model_name_or_path "${MODEL_NAME_OR_PATH}"
    --train_file_path "${TRAIN_FILE_PATH}"
    --eval_file_path "${EVAL_FILE_PATH}"
    --output_dir "${OUTPUT_DIR}"

    # Experiment tracking.
    --report_to "tensorboard,wandb"
    --project_name "${PROJECT_NAME}"
    --run_name "${RUN_NAME}"

    # Optimisation and batching.
    --num_train_epochs "${EPOCHS}"
    --per_device_train_batch_size "${BATCH_SIZE}"
    --per_device_eval_batch_size 8
    --reward_batch_size 8
    --gradient_accumulation_steps 1
    --logging_steps 5
    --learning_rate "${LR}"

    # Remaining options are forwarded verbatim; their semantics are defined
    # by train_VAR.py.
    --eval_first True
    --pre_defined_B 8
    --use_reward_api
    --deepspeed_stage 2
    --max_prompt_length 512
    --max_length 256
    --save_last
    --use_sys_prompt
    --system_prompt "You are a pirate chatbot who always responds in pirate speak!"
)

torchrun "${launcher_args[@]}" train_VAR.py "${train_args[@]}"
