#!/bin/bash

# Fixed training settings.
max_len=4096    # forwarded to the trainer as --model_max_length
sample_size=4   # NOTE(review): not referenced anywhere else in the visible script

# Positional arguments. Every argument is optional; an omitted OR empty
# argument falls back to its default, so "" can be passed to skip a position
# (e.g. `script.sh vicuna-gcn lp "" 32` keeps the default dataset).
#
#   $1 model            model/adapter preset name
#   $2 task             value forwarded as --use_task
#   $3 dataset          value forwarded as --use_dataset (also part of the output dir)
#   $4 bs               per-device train batch size
#   $5 emb              pretrained embedding type
#   $6 num_adapter      value forwarded as --num_adapter
#   $7 use_hop          value forwarded as --use_hop
#   $8 seed             value forwarded as --use_seed
#   $9 graph_transform  value forwarded as --graph_transform ("None" by default)
#
# NOTE(review): the default model "vicuna" is not one of the names the
# model-selection logic below recognises -- confirm the intended default.
model=${1:-"vicuna"}
task=${2:-"lp"}
# Uses ":-" like every other argument (was "-", which let an empty string
# through and produced an empty dataset name).
dataset=${3:-"arxiv-products"}
bs=${4:-16}
emb=${5:-"simteg"}
num_adapter=${6:-2}
use_hop=${7:-2}
seed=${8:-0}
graph_transform=${9:-None}

# Resolve the per-model settings from ${model}:
#   template      prompt template name (NOTE(review): set here but not
#                 referenced anywhere else in the visible script)
#   adapter_type  forwarded as --adapter_type
#   prefix        checkpoint directory prefix, built from emb/use_hop/
#                 adapter_type/num_adapter
#   model_base    forwarded as --model_name_or_path
#   mode          forwarded as --version
#   EXTRA         extra trainer flags appended to the launch command
#
# EXTRA is only assigned by the "vicuna-gin-desc" preset; for every other
# preset a value inherited from the environment (if any) is kept as-is.
EXTRA=${EXTRA:-}

case "${model}" in
  vicuna-linear | vicuna-gcn | vicuna-gin | vicuna-gat | vicuna-sage)
    # The adapter type is the part of the preset name after "vicuna-".
    template="anti-ND"
    adapter_type=${model#vicuna-}
    prefix=llaga-vicuna-7b-${emb}-${use_hop}-${adapter_type}-${num_adapter}-projector
    model_base=lmsys/vicuna-7b-v1.5-16k
    mode="v1"
    ;;
  vicuna-gin-desc)
    # GIN adapter variant that also passes --use_description to the trainer.
    template="ND"
    adapter_type="gin"
    prefix=llaga-vicuna-7b-${emb}-${use_hop}-${adapter_type}-${num_adapter}-projector-atom-desc
    model_base=lmsys/vicuna-7b-v1.5-16k
    mode="v1"
    EXTRA="--use_description"
    ;;
  llama3)
    # This preset forces use_hop to 2 regardless of the 7th argument.
    use_hop=2
    template="ND"
    adapter_type="linear"
    prefix=llaga-llama-3-8B-${emb}-${use_hop}-${adapter_type}-${num_adapter}-projector
    model_base=meta-llama/Meta-Llama-3-8B
    mode="llaga_llama_3"
    ;;
  llama)
    # This preset forces use_hop to 2 regardless of the 7th argument.
    use_hop=2
    template="ND"
    adapter_type="linear"
    prefix=llaga-llama-2-7b-hf-${emb}-${use_hop}-${adapter_type}-${num_adapter}-projector
    model_base=meta-llama/Llama-2-7b-hf
    mode="llaga_llama_2"
    ;;
  *)
    # Fail fast: without a matching preset, prefix/model_base/mode/adapter_type
    # would be unset and the launch would run with missing values.
    echo "error: unsupported model '${model}'; expected one of:" \
      "vicuna-linear vicuna-gcn vicuna-gin vicuna-gin-desc vicuna-gat" \
      "vicuna-sage llama3 llama" >&2
    exit 1
    ;;
esac

# Tag the checkpoint prefix with the graph transform when one of the
# recognised transforms is selected. Any other value (including the default
# "None", or an empty string) leaves the prefix untouched.
case "${graph_transform}" in
  gdc | sdc | sgdc | attn | biattn)
    prefix=${prefix}-${graph_transform}
    ;;
esac

echo "PREFIX:  ${prefix}"

# Put Weights & Biases into offline mode before training; the trainer is still
# started with --report_to wandb below.
wandb offline

# Machine-specific launcher settings (hard-coded in this script).
ds_include="localhost:2,3,4,5,6,7"  # value passed to deepspeed --include
ds_master_port=61000                # value passed to deepspeed --master_port
hf_cache_dir="/data/haotian/.cache" # value passed to the trainer as --cache_dir

# EXTRA carries zero or more whitespace-separated trainer flags, so it is
# split into words on purpose.
# shellcheck disable=SC2206
extra_args=(${EXTRA:-})

# The full launch command is built exactly once so that the line printed below
# and the command actually executed can never drift apart.
train_cmd=(
  deepspeed --include "${ds_include}" --master_port "${ds_master_port}" train/train_mem_pyg.py
  --deepspeed ./scripts/zero2.json
  --model_name_or_path "${model_base}"
  --version "${mode}"
  --cache_dir "${hf_cache_dir}"
  --pretrained_embedding_type "${emb}"
  --tune_mm_mlp_adapter True
  --mm_use_graph_start_end False
  --mm_use_graph_patch_token False
  --bf16 True
  --output_dir "./checkpoints/${dataset}/${prefix}_${task}_seed_${seed}"
  --num_train_epochs 1
  --per_device_train_batch_size "${bs}"
  --per_device_eval_batch_size 4
  --gradient_accumulation_steps 1
  --evaluation_strategy "no"
  --save_strategy "epoch"
  --learning_rate 1e-4
  --weight_decay 0.
  --warmup_ratio 0.03
  --lr_scheduler_type "cosine"
  --logging_steps 1
  --tf32 True
  --model_max_length "${max_len}"
  --gradient_checkpointing True
  --lazy_preprocess True
  --report_to wandb
  --adapter_type "${adapter_type}"
  --num_adapter "${num_adapter}"
  --graph_transform "${graph_transform}"
  --use_task "${task}"
  --use_hop "${use_hop}"
  --use_dataset "${dataset}"
  --use_seed "${seed}"
  "${extra_args[@]}"
)

# Print the command (space-joined, on one line) before running it.
echo "${train_cmd[*]}"

# Optional debugging switches; uncomment as needed.
# export NCCL_DEBUG=INFO
# export NCCL_DEBUG_SUBSYS=ALL
# export PYTHONFAULTHANDLER=1
# export CUDA_LAUNCH_BLOCKING=1
# export TORCH_DISTRIBUTED_DEBUG=DETAIL

# Run the training job; its exit status becomes the script's exit status.
"${train_cmd[@]}"