#!/bin/bash
#
# Two-stage Bio-LLaVA training launcher: a projector-pretraining DeepSpeed
# run followed by a LoRA fine-tuning run (see the stage sections below).

# Abort on the first failing command (in particular, never start stage two
# after stage one has failed), on use of an unset variable, and on failures
# inside pipelines.
set -euo pipefail

# Caps build parallelism when PyTorch compiles C++/CUDA extensions
# (read by torch.utils.cpp_extension), e.g. DeepSpeed ops built on first use.
export MAX_JOBS=20

# Conda environment "bio_llava": its shared libraries and executables take
# precedence over anything already on the search paths.
conda_root=/data1/tianlong/anaconda3
conda_env=$conda_root/envs/bio_llava
# ${VAR:+:$VAR} appends the old value only when it is non-empty. A bare
# trailing ':' would create an empty entry, which the dynamic loader treats
# as the current working directory.
export LD_LIBRARY_PATH="$conda_env/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
# Resulting order: env bin, then base conda bin, then the inherited PATH.
export PATH="$conda_env/bin:$conda_root/bin:$PATH"
# Hugging Face model cache. Newer transformers releases deprecate this
# variable in favour of HF_HOME / HF_HUB_CACHE.
export TRANSFORMERS_CACHE=/data1/tianlong/cache/
# Redirect $HOME to the data volume. Presumably this keeps tool caches and
# configs (~/.cache, ~/.triton, ...) off the default home directory; confirm.
export HOME=/data1/tianlong

# Alternative environment (txplm); swap these in for the bio_llava lines above.
# export LD_LIBRARY_PATH=/data1/tianlong/anaconda3/envs/txplm/lib:$LD_LIBRARY_PATH
# export PATH="/data1/tianlong/anaconda3/envs/txplm/bin:$PATH"

# Launcher used by both training stages below.
deepspeed_p=$conda_env/bin/deepspeed
# Not referenced elsewhere in this script. Presumably read from the
# environment by the training code (the name suggests a tokenizer path).
export TOKEN_PATH=/data1/tianlong/LLaVA_ckpt/vicuna-13b-v1.5


# GPU ids passed as `deepspeed --include localhost:<ids>`; GPU 0 is left out.
# Index 8 exists only on a host with at least nine GPUs, so adjust per machine.
gpus="1,2,3,4,5,6,7,8"
# gpus="0,1"

# ---------------------------------------------------------------------------
# Stage one: multimodal projector pretraining (--tune_mm_mlp_adapter True).
# Uses deepspeed_p and gpus from the environment setup above.
# ---------------------------------------------------------------------------
lr="5e-5"   # also used as stage two's --learning_rate
pretrain_out_dir=./checkpoints/bio-llava-pretrain-stage-one-vicuna-mmseq-smiles-13b
llava_ckpt_dir=/data1/tianlong/LLaVA_ckpt

# Trainer arguments as a bash array: each element reaches the Python script
# verbatim, with no word-splitting or globbing of the values.
pretrain_args=(
    --deepspeed ./scripts/zero2.json
    --model_name_or_path "$llava_ckpt_dir/vicuna-13b-v1.5"
    --version plain
    --is_multimodal True
    --go_term_graph datasets/GOA_Human/go.obo
    --protein_pkl datasets/GOA_Human/train_data_fold_0.pkl
    --protein_tower "$llava_ckpt_dir/esm2_t33_650M_UR50D"
    --mm_protein_tower "$llava_ckpt_dir/esm2_t33_650M_UR50D"
    --mm_projector_type mlp2x_gelu
    --tune_mm_mlp_adapter True
    --mm_protein_select_layer -1
    --mm_use_prot_start_end False
    --mm_use_prot_patch_token False
    --bf16 True
    --output_dir "$pretrain_out_dir"
    --num_train_epochs 5
    --per_device_train_batch_size 16
    --per_device_eval_batch_size 16
    --gradient_accumulation_steps 1
    --evaluation_strategy "steps"
    --eval_steps 50
    --save_strategy "epoch"
    --save_total_limit 1
    --learning_rate "$lr"
    --weight_decay 0.
    --warmup_ratio 0.03
    --lr_scheduler_type "cosine"
    --logging_steps 1
    --tf32 True
    --model_max_length 2048
    --gradient_checkpointing True
    --dataloader_num_workers 4
    --lazy_preprocess False
    --report_to tensorboard
    --with_comments False
    --training_csv_file datasets/split70.csv
    --retrieval_mmseq True
    --retrieval_smiles True
)

# CUDA_LAUNCH_BLOCKING=1 makes every CUDA kernel launch synchronous. That
# helps pinpoint a failing op but slows training; drop it once the run is
# stable. NCCL_P2P_DISABLE=1 turns off NCCL's GPU peer-to-peer (NVLink/PCIe)
# transport. Presumably this works around a P2P problem on this host.
# On failure, stop here with the launcher's status so stage two never runs
# against missing projector weights.
CUDA_LAUNCH_BLOCKING=1 NCCL_P2P_DISABLE=1 "$deepspeed_p" \
    --include "localhost:$gpus" --master_port=29601 \
    llava/train/train_mem_bio.py "${pretrain_args[@]}" || {
    rc=$?
    printf 'error: stage one (projector pretraining) failed with status %d; not starting stage two\n' "$rc" >&2
    exit "$rc"
}

# sleep 10


# ---------------------------------------------------------------------------
# Stage two: LoRA fine-tuning (--lora_enable True), initialised from the
# projector weights produced by stage one under $pretrain_out_dir.
# Uses lr, pretrain_out_dir, deepspeed_p and gpus defined earlier.
# ---------------------------------------------------------------------------
output_dir=./checkpoints/bio-llava-fine-tune-lora-vicuna-mmseq-smiles-13b
pretrained_mm_adapter=$pretrain_out_dir/mm_projector.bin

# Use the same base LLM / protein-encoder checkpoints as stage one. The
# projector passed via --pretrain_mm_mlp_adapter was trained against them.
llava_ckpt_dir=/data1/tianlong/LLaVA_ckpt

# Fail fast, rather than spinning up every GPU only to crash on a missing file.
# NOTE(review): assumes stage one saves mm_projector.bin directly inside
# $pretrain_out_dir; confirm against llava/train/train_mem_bio.py.
if [[ ! -f "$pretrained_mm_adapter" ]]; then
    printf 'error: stage-one projector weights not found: %s\n' "$pretrained_mm_adapter" >&2
    exit 1
fi

# Trainer arguments as a bash array: each element reaches the Python script
# verbatim, with no word-splitting or globbing of the values.
finetune_args=(
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5
    --deepspeed ./scripts/zero2.json
    --model_name_or_path "$llava_ckpt_dir/vicuna-13b-v1.5"
    --version v1_biollava
    --is_multimodal True
    --go_term_graph datasets/GOA_Human/go.obo
    --protein_pkl datasets/GOA_Human/train_data_fold_0.pkl
    --protein_tower "$llava_ckpt_dir/esm2_t33_650M_UR50D"
    --mm_protein_tower "$llava_ckpt_dir/esm2_t33_650M_UR50D"
    --mm_projector_type mlp2x_gelu
    --mm_protein_select_layer -1
    --mm_use_prot_start_end False
    --mm_use_prot_patch_token False
    --bf16 True
    --output_dir "$output_dir"
    --num_train_epochs 10
    --per_device_train_batch_size 16
    --per_device_eval_batch_size 16
    --gradient_accumulation_steps 1
    --evaluation_strategy "steps"
    --eval_steps 100
    --save_strategy "epoch"
    --save_total_limit 1
    --learning_rate "$lr"
    --weight_decay 0.
    --warmup_ratio 0.03
    --lr_scheduler_type "cosine"
    --logging_steps 1
    --tf32 True
    --model_max_length 2048
    --gradient_checkpointing True
    --dataloader_num_workers 4
    --lazy_preprocess True
    --pretrain_mm_mlp_adapter "$pretrained_mm_adapter"
    --tune_norm_layer False
    --report_to tensorboard
    --with_comments False
    --training_csv_file datasets/split70.csv
    --retrieval_mmseq True
    --retrieval_smiles True
)

# NCCL_P2P_DISABLE=1 turns off NCCL's GPU peer-to-peer transport, as in stage
# one. The script's exit status is this launcher's exit status.
NCCL_P2P_DISABLE=1 "$deepspeed_p" --include "localhost:$gpus" --master_port=29601 \
    llava/train/train_mem_bio.py "${finetune_args[@]}"