#!/bin/bash
#
# DeepSpeed supervised fine-tuning launcher for the phi35-siglip224 sparse-MoE
# model using the "smoe_graphgating" MoE variant.
#
# Usage: <this-script> SETTING ALPHA BETA NORM_TYPE NWARMUP_STEPS
#   SETTING 1 -> gate_sym=true, SETTING 2 -> gate_sym=false.
#
# Every value in this section can be overridden from the environment without
# editing the file, e.g.:
#   ID_GPUS=0,1 MAX_STEPS=10 <this-script> 1 ALPHA BETA NORM_TYPE NWARMUP_STEPS
#
# Reads the Hugging Face token from ./huggingface_key.secret.

set -euo pipefail

# Root for checkpoints/ and data/, plus scratch/tool locations. Empty values
# are presumably placeholders to be filled in per machine — TODO confirm.
export checkpoints="${checkpoints:-.}"
export TMPDIR="${TMPDIR:-}"
export TOOLKIT_DIR="${TOOLKIT_DIR:-.}"

# Hugging Face access token. A missing/unreadable key file is reported on
# stderr but is not fatal; KEY_HF is then exported as an empty string.
if ! KEY_HF=$(cat huggingface_key.secret); then
  printf 'warning: could not read huggingface_key.secret; KEY_HF is empty\n' >&2
  KEY_HF=""
fi
export KEY_HF

# Comma-separated GPU ids handed to `deepspeed --include localhost:...`.
export ID_GPUS="${ID_GPUS:-0,1,2,3}"
# Forwarded as --max_steps; -1 presumably means "no step cap" (epochs decide).
export MAX_STEPS="${MAX_STEPS:--1}"
export MODELDIR="${MODELDIR:-phi35-siglip224}"

# Prepend the relative ./LibMoE directory. Only add the ':' separator when a
# PYTHONPATH already exists: a trailing empty entry would make Python put the
# current working directory on sys.path as well.
export PYTHONPATH="./LibMoE${PYTHONPATH:+:$PYTHONPATH}"

export TMUX_TMPDIR="${TMUX_TMPDIR:-$HOME/tmux_tmp}"
export WANDB_ENTITY="${WANDB_ENTITY:-}"
export WANDB_PROJECT="${WANDB_PROJECT:-}"

# Graph routing settings, taken from the positional arguments:
#   $1 setting        1 -> symmetric gate (gate_sym=true),
#                     2 -> asymmetric gate (gate_sym=false)
#   $2 alpha          passed as --gate_alpha
#   $3 beta           passed as --gate_beta
#   $4 norm_type      passed as --gate_norm_type
#   $5 nwarmup_steps  passed as --gate_warmup_nsteps
if (( $# < 5 )); then
  printf 'Usage: %s SETTING(1|2) ALPHA BETA NORM_TYPE NWARMUP_STEPS\n' "${0##*/}" >&2
  exit 2
fi

setting=$1
alpha=$2
beta=$3
norm_type=$4
nwarmup_steps=$5

# Any other SETTING used to fall through silently, leaving gate_sym/jobname
# unset and producing an empty --run_name and a bare .../sft/ output dir.
case "$setting" in
  1) gate_sym=true;  sym_tag=sym ;;
  2) gate_sym=false; sym_tag=asym ;;
  *)
    printf 'error: SETTING must be 1 (symmetric) or 2 (asymmetric), got %q\n' "$setting" >&2
    exit 2
    ;;
esac

# Used both as the W&B run name and as the output sub-directory under sft/.
jobname="graphglobal-m-${setting}-alpha${alpha}-beta${beta}-${sym_tag}-${norm_type}-nwarmup${nwarmup_steps}-665k"

# Launch SFT from the checkpoint under $checkpoints/checkpoints/$MODELDIR/pft,
# sparse-upcycled into an MoE with 4 experts / 2 selected, writing results to
# .../sft/$jobname.
#
# Trainer arguments are collected in an array so every expansion becomes
# exactly one argv entry: paths containing spaces are not split, and an empty
# value (e.g. an unset $jobname) stays an explicit "" instead of vanishing
# and shifting every flag that follows it.
train_args=(
  # DeepSpeed config, model, data and vision-tower locations.
  --deepspeed ./scripts/zero3.json
  --model_name_or_path "$checkpoints/checkpoints/$MODELDIR/pft"
  --version phi35
  --data_path "$checkpoints/data/jsons/llava_v1_5_mix665k.json"
  --image_folder "$checkpoints/data"
  --vision_tower google/siglip-so400m-patch14-224
  --vision_tower_dir "$checkpoints/checkpoints/$MODELDIR/pft/clip.bin"
  --scales 1,3
  --pretrain_mm_mlp_adapter "$checkpoints/checkpoints/$MODELDIR/pft/mm_projector.bin"
  --mm_projector_type moe
  # MoE layers and gating hyper-parameters (several come from the CLI args).
  --mlp_smoe true
  --clip_smoe true
  --gate_sym "$gate_sym"
  --gate_norm_type "$norm_type"
  --gate_alpha "$alpha"
  --gate_beta "$beta"
  --gate_softmax_temp 1.0
  --gate_warmup_nsteps "$nwarmup_steps"
  --moe_name "smoe_graphgating"
  --num_experts 4
  --num_selected 2
  --sparse_upcycling true
  --balance_loss_coef 0.01
  --router_z_loss_coef 0.001
  # Multimodal input handling.
  --mm_vision_select_layer -2
  --mm_use_im_start_end False
  --mm_use_im_patch_token False
  --image_aspect_ratio pad
  --group_by_modality_length True
  # Precision, output, optimisation schedule, checkpointing and logging.
  --bf16 True
  --output_dir "$checkpoints/checkpoints/$MODELDIR/sft/$jobname"
  --num_train_epochs 1
  --per_device_train_batch_size 5
  --per_device_eval_batch_size 1
  --gradient_accumulation_steps 2
  --evaluation_strategy "no"
  --save_strategy "steps"
  --save_steps 1664
  --save_total_limit 16
  --learning_rate 4e-6
  --weight_decay 0.
  --warmup_ratio 0.03
  --lr_scheduler_type "cosine"
  --logging_steps 1
  --tf32 True
  --model_max_length 2048
  --gradient_checkpointing True
  --dataloader_num_workers 4
  --lazy_preprocess True
  --report_to wandb
  --run_name "$jobname"
  --max_steps "$MAX_STEPS"
)

deepspeed --master_port 29520 --include "localhost:$ID_GPUS" \
  moe_model/train/train_mem.py "${train_args[@]}"
