# Print GPU and CUDA toolchain details up front so they show up in the job log.
nvidia-smi
nvcc --version

# SECURITY: never hard-code a Hugging Face access token in a script. A token
# that was ever committed here must be treated as leaked and revoked.
# The token is taken from the caller's environment instead; when it is empty,
# the fallback later in this script tries to read it from "$HF_HOME/token".
export HF_TOKEN="${HF_TOKEN:-}"

# Offline training: uncomment the exports below to use local copies of the models and datasets.
# export HF_HUB_OFFLINE=1
# export TRANSFORMERS_OFFLINE=1
# export HF_DATASETS_OFFLINE=1

# Choose the dataset config file; HF_DATASETS_OFFLINE=1 selects the offline one.
case "$HF_DATASETS_OFFLINE" in
    1)
        echo "Warning: Offline mode is enabled. Using local copy of datasets"
        DATA_CONFIG_FILE="./data_configs/train_config_offline.yaml"
        ;;
    *)
        DATA_CONFIG_FILE="./data_configs/train_config_vln.yaml"
        # DATA_CONFIG_FILE="./data_configs/train_config.yaml"
        ;;
esac
# Choose the base model; TRANSFORMERS_OFFLINE=1 switches to a local path.
# "{local_model_path}" is a placeholder that has to be edited before offline use.
case "$TRANSFORMERS_OFFLINE" in
    1)
        echo "Warning: Offline mode is enabled. Using local copy of models"
        model_name_or_path="{local_model_path}"
        ;;
    *)
        # model_name_or_path="TIGER-Lab/Mantis-8B-clip-llama3-pretraind"
        model_name_or_path="TIGER-Lab/Mantis-8B-siglip-llama3"
        ;;
esac
# Pushing to the Hub is on by default and switched off when HF_HUB_OFFLINE=1.
push_to_hub=True
if [ "$HF_HUB_OFFLINE" = 1 ]; then
    echo "Warning: Offline mode is enabled. Using local copy of model and datasets"
    push_to_hub=False
fi
# Resolve the Hugging Face cache directory and the access token.
# Order of precedence for the token: HF_TOKEN from the environment, then the
# contents of "$HF_HOME/token". The script aborts when neither yields a token.
if [ -z "${HF_HOME:-}" ]; then
    echo "HF_HOME is empty, set to default '~/.cache/huggingface/'"
    # Use $HOME here: a "~" inside double quotes is never tilde-expanded, so the
    # exported value would be a literal "~/..." path for every child process.
    export HF_HOME="${HOME}/.cache/huggingface/"
fi
if [ -z "${HF_TOKEN:-}" ]; then
    echo "HF token is empty, try loading from '$HF_HOME/token'"
    # Read the file directly; no eval is needed once HF_HOME is a real path.
    if [ -r "${HF_HOME%/}/token" ]; then
        HF_TOKEN=$(cat "${HF_HOME%/}/token")
        export HF_TOKEN
    fi
fi
if [ -z "${HF_TOKEN:-}" ]; then
    echo "HF token cannot be found, please set your HF token" >&2
    exit 1
fi

# Hub account that receives the trained model; while this stays empty the
# script turns pushing off further down.
hf_hub_user_name=""

# Model / data settings handed to the training command.
mllava_type="llava"
DATA_FORMAT="chat"
max_seq_len=8192
lora_enabled=true

# Target global batch size (used to derive the gradient accumulation steps)
# and the directory passed as --output_dir.
global_batch_size=32
OUTPUT_DIR="../../checkpoints"

# Build the run name: <type>_siglip_llama3_8b_vln_post_ft_train_<max_seq_len>,
# with a "_lora" suffix when LoRA is enabled.
export WANDB_PROJECT="Mantis"
RUN_NAME="${mllava_type}_siglip_llama3_8b_vln_post_ft_train_${max_seq_len}"
if [ "$lora_enabled" = true ]; then
    echo "lora is enabled"
    RUN_NAME="${RUN_NAME}_lora"
else
    echo "lora is disabled"
fi
echo "RUN_NAME = $RUN_NAME"

# Hub destination for the trained model: "<user>/<run name>", authenticated
# with the token held in HF_TOKEN.
hub_token=$HF_TOKEN
hub_model_id="${hf_hub_user_name}/${RUN_NAME}"
if [ -n "$hf_hub_user_name" ]; then
    echo "hf_hub_user_name = $hf_hub_user_name"
else
    # Without a user name there is nowhere to push to, so disable the push.
    echo "hf_hub_user_name is empty, do not push to hub"
    push_to_hub=False
fi

# Resume support: set resume_from_checkpoint to an existing checkpoint
# directory and replace the WANDB_LAST_RUN_ID placeholder with the run to
# continue. Left empty, the run starts from scratch.
resume_from_checkpoint=""
# The operand must be quoted: with an empty, unquoted value the test collapses
# to `[ -d ]`, a one-argument test that is always true.
if [ -d "$resume_from_checkpoint" ]; then
    echo "resume_from_checkpoint = $resume_from_checkpoint"
    export WANDB_LAST_RUN_ID="your_last_run_id"
else
    echo "No checkpoint found, training from scratch"
fi

# Toolchain / NCCL settings for the child processes.
export NCCL_DEBUG=INFO
export CXX=g++

# Export the rendezvous settings inherited from the environment; the node
# count is taken from WORLD_SIZE.
export MASTER_ADDR="$MASTER_ADDR"
export MASTER_PORT="$MASTER_PORT"
export COUNT_NODE="$WORLD_SIZE"

# Fill in single-node defaults for anything the environment did not provide.
[ -n "$HOSTNAMES" ] || {
    echo "HOSTNAMES is empty"
    export HOSTNAMES=$(hostname | awk '{print $1}')
}
[ -n "$MASTER_ADDR" ] || {
    echo "MASTER_ADDR is empty"
    export MASTER_ADDR=$(hostname -I | awk '{print $1}')
}
[ -n "$MASTER_PORT" ] || {
    echo "MASTER_PORT is empty"
    export MASTER_PORT=12956
}
[ -n "$COUNT_NODE" ] || {
    echo "COUNT_NODE is empty"
    export COUNT_NODE=1
}
[ -n "$RANK" ] || {
    echo "RANK is empty"
    export RANK=0
}


# Count the GPUs selected on this node.
#   stdin:  one GPU index per line (the nvidia-smi index listing)
#   $1:     comma-separated index list in CUDA_VISIBLE_DEVICES form;
#           empty means "every listed GPU"
#   stdout: how many listed indices are selected
# Matching is whole-line and literal (-x -F): a plain substring/regex match
# would let "1" also count GPUs 10, 11, ... on nodes with more than ten GPUs.
# Entries that are not plain indices (e.g. GPU UUIDs) never match the index
# listing and are therefore not counted.
_count_visible_gpus() {
    if [ -z "$1" ]; then
        grep -c ''
    else
        tr -d ' \t' |
            grep -cxF -e "$(printf '%s\n' "$1" | tr -d ' \t' | tr ',' '\n')"
    fi
}

NGPU_PER_NODE=$(nvidia-smi --query-gpu=index --format=csv,noheader |
    _count_visible_gpus "${CUDA_VISIBLE_DEVICES:-}")
GPU=$((${COUNT_NODE:-1} * NGPU_PER_NODE))
if [ "$GPU" -le 0 ]; then
    echo "Warning: no usable GPU detected (CUDA_VISIBLE_DEVICES='${CUDA_VISIBLE_DEVICES:-}')" >&2
fi

# Four dataloader workers per process, capped at 112 in total.
WORKERS=$((GPU * 4))
if [ "$WORKERS" -gt 112 ]; then
    WORKERS=112
fi

# Log the resolved launch topology. The variable expansions stay outside the
# quotes on purpose so the printed text is unchanged (echo joins its
# arguments with single spaces).
echo "HOSTNAMES =" $HOSTNAMES
echo "MASTER_ADDR=" $MASTER_ADDR
echo "MASTER_PORT=" $MASTER_PORT
echo "COUNT_NODE=" $COUNT_NODE
echo "RANK=" $RANK
echo "GPU=${GPU}"
echo "WORKERS=${WORKERS}"
echo "Running ${RUN_NAME}"


# Select the accelerate config: the zero2 file for LoRA runs, the zero3 file
# otherwise.
case "$lora_enabled" in
    true)
        echo "lora is enabled"
        config_file="./accelerate_configs/accelerate_config_zero2.yaml"
        ;;
    *)
        echo "lora is disabled"
        config_file="./accelerate_configs/accelerate_config_zero3.yaml"
        ;;
esac
echo "$config_file"

# Compute gradient accumulation steps for a target global batch size.
#   $1: target global batch size
#   $2: per-device batch size
#   $3: total number of processes (GPUs)
#   stdout: $1 / ($2 * $3) using integer division, never less than 1
# A zero process count would otherwise be a division by zero, and a process
# count above the target batch size would silently yield 0 steps.
# Runs in a subshell, so its working variables do not leak into the script.
_grad_accum_steps() (
    target=${1:-0}
    denom=$((${2:-0} * ${3:-0}))
    if [ "$denom" -le 0 ]; then
        echo "Warning: per-device batch size * process count is ${denom}; using 1 accumulation step" >&2
        echo 1
        return
    fi
    steps=$((target / denom))
    if [ "$steps" -lt 1 ]; then
        echo "Warning: ${denom} samples per step already exceed the target global batch size ${target}; using 1 accumulation step" >&2
        steps=1
    fi
    echo "$steps"
)

per_device_train_batch_size=1
gradient_accumulation_steps=$(_grad_accum_steps "$global_batch_size" "$per_device_train_batch_size" "$GPU")
echo "gradient_accumulation_steps=$global_batch_size / ($per_device_train_batch_size * $GPU) = $gradient_accumulation_steps"

# sleep 12600;
# Launch train_mllava.py through accelerate using the settings resolved above.
# Every expansion is quoted so a value containing whitespace or glob characters
# reaches the trainer as exactly one argument.
#
# NOTE(review): tune_xatten_layer_only is not assigned anywhere in this script.
# Unless the caller's environment provides it, the flag is passed without a
# value; the ${var:+...} form below keeps exactly that behaviour. Confirm the
# intended value and set it explicitly.
# NOTE(review): the token is passed on the command line via --hub_token and is
# therefore visible in the process list.
#
# To resume a run, append this option to the command below:
#   --resume_from_checkpoint "$resume_from_checkpoint"
accelerate launch --config_file="$config_file" \
    --machine_rank "$RANK" \
    --main_process_ip "$MASTER_ADDR" \
    --main_process_port "$MASTER_PORT" \
    --num_machines="${COUNT_NODE}" --num_processes="${GPU}" \
    train_mllava.py \
    --ddp_find_unused_parameters False \
    --model_name_or_path "$model_name_or_path" \
    --data_config_file "$DATA_CONFIG_FILE" \
    --data_format "$DATA_FORMAT" \
    --run_name "$RUN_NAME" \
    --bf16 True \
    --output_dir "$OUTPUT_DIR" \
    --hub_model_id "$hub_model_id" \
    --hub_token "$hub_token" \
    --push_to_hub "$push_to_hub" \
    --num_train_epochs 10 \
    --per_device_train_batch_size "$per_device_train_batch_size" \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps "$gradient_accumulation_steps" \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_only_model \
    --save_steps 300 \
    --eval_steps 300 \
    --save_total_limit 50 \
    --learning_rate 1e-5 \
    --weight_decay 0.0 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type cosine \
    --logging_steps 150 \
    --tf32 True \
    --gradient_checkpointing True \
    --dataloader_num_workers "$WORKERS" \
    --report_to tensorboard \
    --do_train \
    --lora_enabled "$lora_enabled" \
    --max_seq_len "$max_seq_len" \
    --tune_xatten_layer_only ${tune_xatten_layer_only:+"$tune_xatten_layer_only"} \
    --mllava_type "$mllava_type"