#!/bin/bash

# Launches projector training via train/train_mem.py.
#
# Usage: <this script> [model] [task] [dataset] [batch_size] [emb] [num_adapter]
#
# Every positional argument is optional; an omitted or empty argument falls
# back to the default listed below.

# Fixed settings (not overridable from the command line).
max_len=4096     # forwarded to the trainer as --model_max_length
sample_size=10   # only used as a component of the checkpoint prefix

# $1: model variant, matched against the names in the model-selection section.
# NOTE(review): the default "vicuna" is not one of the names that section
# recognises (vicuna-linear, vicuna-gcn, vicuna-gin, vicuna-gat, vicuna-sage,
# llama3, llama), so pass $1 explicitly until the intended default is
# confirmed.
model="${1:-vicuna}"
# $2: forwarded as --use_task and used in the output directory name.
task="${2:-nc}"
# $3: forwarded as --use_dataset and used as the checkpoint sub-directory.
# Uses ":-" like its siblings so an explicitly empty argument also gets the
# default instead of producing an empty path component.
dataset="${3:-arxiv-products}"
# $4: forwarded as --per_device_train_batch_size.
bs="${4:-16}"
# $5: forwarded as --pretrained_embedding_type and used in the prefix.
emb="${5:-simteg}"
# $6: forwarded as --num_adapter.
num_adapter="${6:-2}"


# Resolve the model variant in ${model} to the settings used when launching
# training:
#   adapter_type - forwarded as --adapter_type
#   model_base   - forwarded as --model_name_or_path
#   mode         - forwarded as --version
#   prefix       - checkpoint directory name prefix
#
# Reads:  model, emb, sample_size
# Writes: use_hop, template, adapter_type, prefix_stem, model_base, mode, prefix

# Identical for every supported variant, so set once.
use_hop=2
template="ND"   # not referenced in the visible part of this script; kept as-is

case "${model}" in
  vicuna-linear | vicuna-gcn | vicuna-gin | vicuna-gat | vicuna-sage)
    # The adapter type is the part of the variant name after "vicuna-".
    adapter_type="${model#vicuna-}"
    prefix_stem="llaga-vicuna-7b"
    model_base="lmsys/vicuna-7b-v1.5-16k"
    mode="v1"
    ;;
  llama3)
    adapter_type="linear"
    prefix_stem="llaga-llama-3-8B"
    model_base="meta-llama/Meta-Llama-3-8B"
    mode="llaga_llama_3"
    ;;
  llama)
    adapter_type="linear"
    prefix_stem="llaga-llama-2-7b-hf"
    model_base="meta-llama/Llama-2-7b-hf"
    mode="llaga_llama_2"
    ;;
  *)
    # Fail fast: without a match none of the settings above would be defined
    # and training would be launched with missing argument values. Note that
    # a bare "vicuna" is not a recognised name.
    printf 'error: unsupported model "%s"\n' "${model}" >&2
    printf 'supported: vicuna-linear vicuna-gcn vicuna-gin vicuna-gat vicuna-sage llama3 llama\n' >&2
    exit 1
    ;;
esac

prefix="${prefix_stem}-${emb}-${use_hop}-${sample_size}-${adapter_type}-projector"


# Print the resolved checkpoint prefix, switch wandb to offline mode, then
# print and execute the training command.
#
# Reads: prefix, model_base, mode, emb, dataset, task, bs, max_len,
#        adapter_type, num_adapter, and seed (see note below).
echo "PREFIX:  ${prefix}"

wandb offline

# The command is assembled once as an array so that the printed command and
# the executed command cannot drift apart, and so that every option and value
# stays a separate word. (Previously the line ending in "${num_adapter}\" had
# no space before the backslash and the following line started at column 0,
# which fused the value with "--use_task" into a single word.)
#
# NOTE(review): "seed" is not assigned anywhere in the visible script; it is
# only non-empty if inherited from the environment. When unset, the output
# directory name ends in a trailing "_" — confirm whether that is intended.
train_cmd=(
  python train/train_mem.py
  --model_name_or_path "${model_base}"
  --version "${mode}"
  --cache_dir /data/haotian/.cache
  --pretrained_embedding_type "${emb}"
  --tune_mm_mlp_adapter True
  --mm_use_graph_start_end False
  --mm_use_graph_patch_token False
  --bf16 True
  --output_dir "./checkpoints/${dataset}/${prefix}_${task}_${seed:-}"
  --num_train_epochs 1
  --per_device_train_batch_size "${bs}"
  --per_device_eval_batch_size 4
  --gradient_accumulation_steps 1
  --evaluation_strategy "no"
  --save_strategy "epoch"
  --learning_rate 2e-3
  --weight_decay 0.
  --warmup_ratio 0.03
  --lr_scheduler_type "cosine"
  --logging_steps 1
  --tf32 True
  --model_max_length "${max_len}"
  --gradient_checkpointing True
  --lazy_preprocess True
  --report_to wandb
  --adapter_type "${adapter_type}"
  --num_adapter "${num_adapter}"
  --use_task "${task}"
  --use_dataset "${dataset}"
)

# Show the exact command on one line (words joined by single spaces).
printf '%s\n' "${train_cmd[*]}"

# Run it; as the last command, its exit status is the script's exit status.
"${train_cmd[@]}"