#!/bin/bash
#
# Launches llava/train/train_mem_bio.py with LoRA enabled, loading the
# mm_projector.bin found in the pretrain output directory below.
#
# Environment:
#   python_p  Optional launcher command (may contain several words, e.g. an
#             interpreter plus options). Defaults to "python" when unset or
#             empty.

# Abort on references to unset variables so a misspelled name cannot silently
# expand to nothing.
set -u

# Exported so child processes inherit it; its consumer is not visible in this
# script.
export MAX_JOBS=20

# Value passed to CUDA_VISIBLE_DEVICES for the training process.
gpus="0"
# gpus="0,1"

lr="5e-5"
pretrain_out_dir=./checkpoints/bio-llava-pretrain-stage-one-vicuna-mmseq-smiles-zzz

output_dir=./checkpoints/bio-llava-fine-tune-lora-vicuna-mmseq-smiles-zzz
pretrained_mm_adapter=$pretrain_out_dir/mm_projector.bin

# Resolve the launcher. Splitting on whitespace mirrors how an unquoted
# $python_p would have been expanded, so multi-word launchers keep working.
launcher=()
read -r -a launcher <<< "${python_p:-python}"

# One array element per argument: no reliance on word splitting, and no nested
# quotes that the shell would strip anyway.
train_args=(
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5
    --model_name_or_path /data1/xxx/LLaVA_ckpt/vicuna-7b-v1.5
    --version v1_biollava
    --is_multimodal True
    --go_term_graph datasets/GOA_Human/go.obo
    --protein_pkl datasets/GOA_Human/train_data_fold_0.pkl
    --protein_tower /data1/xxx/LLaVA_ckpt/esm2_t33_650M_UR50D
    --mm_protein_tower /data1/xxx/LLaVA_ckpt/esm2_t33_650M_UR50D
    --mm_projector_type mlp2x_gelu
    --mm_protein_select_layer -1
    --mm_use_prot_start_end False
    --mm_use_prot_patch_token False
    --bf16 True
    --output_dir "$output_dir"
    --num_train_epochs 10
    --per_device_train_batch_size 16
    --per_device_eval_batch_size 16
    --gradient_accumulation_steps 1
    --evaluation_strategy steps
    --eval_steps 100
    --save_strategy epoch
    --save_total_limit 1
    --learning_rate "$lr"
    --weight_decay 0.
    --warmup_ratio 0.03
    --lr_scheduler_type cosine
    --logging_steps 1
    --tf32 True
    --model_max_length 2048
    --gradient_checkpointing True
    --dataloader_num_workers 4
    --lazy_preprocess True
    --pretrain_mm_mlp_adapter "$pretrained_mm_adapter"
    --tune_norm_layer False
    --report_to tensorboard
    --with_comments False
    --training_csv_file datasets/split70.csv
    --retrieval_mmseq True
    --retrieval_smiles True
    --is_test True
)

# The launch is the last command, so its exit status is the script's status.
NCCL_P2P_DISABLE=1 CUDA_VISIBLE_DEVICES="$gpus" \
    "${launcher[@]}" llava/train/train_mem_bio.py "${train_args[@]}"