#!/usr/bin/env bash
# Multi-node pre-training launcher: picks this host's machine rank from its
# hostname and runs src/pre_train.py through `accelerate launch`.
# Run this script on every host listed in the rank table below.

# Start timestamp (an end timestamp is printed after training returns).
date

# configs/ and src/ used by the launch command are relative to this directory,
# so continuing after a failed cd would launch the wrong (or no) script.
cd YOUR_ROOT_PATH/MLLM || { echo "cannot cd to YOUR_ROOT_PATH/MLLM" >&2; exit 1; }

# Cluster topology shared by every node of the job.
# MASTER_ADDR / MASTER_PORT are passed to accelerate as
# --main_process_ip / --main_process_port.
MASTER_ADDR=gpu03
MASTER_PORT=30098
NNODES=4
GPUS_PER_NODE=8
# Total process count handed to --num_processes: NNODES * GPUS_PER_NODE.
NUM_PROCESSES=$((NNODES * GPUS_PER_NODE))

echo "MASTER_ADDR: $MASTER_ADDR, MASTER_PORT: $MASTER_PORT, NNODES: $NNODES, GPUS_PER_NODE: $GPUS_PER_NODE, NUM_PROCESSES: $NUM_PROCESSES"

RUN_NAME="scalability_OI+IC_all_with_JD_factorized_mlp"
# LOG_PATH="/dev/null"

# Map this host to its accelerate machine rank and its log file.
# Rank 0 (gpu03, the MASTER_ADDR host) logs under the run name; the other
# nodes each log to a per-host file.
# HOSTNAME is set by bash; fall back to hostname(1) for shells that do not set it.
host=${HOSTNAME:-$(hostname)}
case "$host" in
  gpu03)
    MACHINE_RANK=0
    LOG_PATH="YOUR_ROOT_PATH/MLLM/logs/scalability/$RUN_NAME.log"
    ;;
  gpu04)
    MACHINE_RANK=1
    LOG_PATH="YOUR_ROOT_PATH/MLLM/logs/gpu04.log"
    ;;
  gpu05)
    MACHINE_RANK=2
    LOG_PATH="YOUR_ROOT_PATH/MLLM/logs/gpu05.log"
    ;;
  gpu06)
    MACHINE_RANK=3
    LOG_PATH="YOUR_ROOT_PATH/MLLM/logs/gpu06.log"
    ;;
  gpu07)
    MACHINE_RANK=4
    LOG_PATH="YOUR_ROOT_PATH/MLLM/logs/gpu07.log"
    ;;
  *)
    # Without a rank, accelerate would be launched with an empty --machine_rank.
    echo "unknown host '$host': no machine rank configured" >&2
    exit 1
    ;;
esac

# Machine ranks are 0-based, so a valid rank is strictly below NNODES.
# Hosts listed above but outside the current job size (e.g. gpu07 while
# NNODES=4) must not try to join the job.
if [ "$MACHINE_RANK" -ge "$NNODES" ]; then
  echo "host '$host' has MACHINE_RANK=$MACHINE_RANK but NNODES=$NNODES; not part of this job" >&2
  exit 1
fi

echo "MACHINE_RANK: $MACHINE_RANK"

echo "RUN_NAME: $RUN_NAME, LOG_PATH: $LOG_PATH"

# An empty LOG_PATH would make the output redirect below fail with an
# obscure error; stop with a clear message instead.
: "${LOG_PATH:?LOG_PATH is not set for this host}"
# The redirect cannot create missing parent directories (e.g. logs/scalability/).
mkdir -p -- "$(dirname -- "$LOG_PATH")" || exit 1

# Options before src/pre_train.py configure `accelerate launch`; everything
# after it is passed through to src/pre_train.py.
# All stdout/stderr of the run is written to LOG_PATH.
accelerate launch --config_file configs/zero2_bf16_custom.yaml \
  --main_process_ip "$MASTER_ADDR" \
  --main_process_port "$MASTER_PORT" \
  --num_machines "$NNODES" \
  --num_processes "$NUM_PROCESSES" \
  --machine_rank "$MACHINE_RANK" \
  src/pre_train.py \
  --use_slow_tokenizer \
  --use_xformers \
  --use_custom_attention_mask \
  --low_cpu_mem_usage \
  --with_tracking \
  --output_dir "YOUR_ROOT_PATH/model/checkpoint/MLLM/$RUN_NAME" \
  --gradient_accumulation_steps 8 \
  --per_device_train_batch_size 2 \
  --per_device_eval_batch_size 12 \
  --num_train_epochs 1 \
  --num_warmup_steps 0.1 \
  --custom_lr_scheduler \
  --min_lr_ratio 0.1 \
  --eval_frequency 1 \
  --learning_rate 0.0002 \
  --loss_scale_visual 0.1 \
  --use_lora \
  --lora_rank 64 \
  --lora_alpha 128 \
  --lora_target_modules "q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" \
  --pretrain_dataset_name "oi+ic" \
  --ic_jd \
  --uni_and_bi \
  --uni_image_prob 0.9 \
  --uni_text_prob 0.9 \
  --oi_time 1 \
  --expand_vocab "factorized" \
  --factorized_linear_mlp \
  --compress_batch \
  --run_name "$RUN_NAME" \
  > "$LOG_PATH" 2>&1
status=$?

# End timestamp.
date

# Propagate the training exit status; otherwise the script would always exit
# with the status of the final `date`, hiding failed runs from the caller.
if [ "$status" -ne 0 ]; then
  echo "training exited with status $status; see $LOG_PATH" >&2
fi
exit "$status"