#!/bin/bash
#
# LoRA hyperparameter sweep for examples/text-classification/run_glue.py.
# For every seed x LoRA rank x learning rate, launches one run with
# torch.distributed on N-GPUS processes. A configuration whose output
# directory already exists is skipped, so re-invoking the script resumes
# an interrupted sweep.
#
# Usage: <script> MODEL_NAME TASK EXPERIMENT_NAME BATCH_SIZE EPOCHS N_GPUS SEQ_LEN
echo "$# arguments"

if (( $# != 7 )); then
  # Diagnostics go to stderr so they are not mixed into captured stdout.
  echo "Arguments should be <MODEL_NAME> <TASK> <EXPERIMENT_NAME> <BATCH_SIZE> <epochs> <N-GPUS> <SEQ-LEN>" >&2
  exit 2
fi

export num_gpus="$6"
# Needed for reproducible cuBLAS results, see
# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
export CUBLAS_WORKSPACE_CONFIG=":16:8"
# Pin Python's str-hash seed so hash-dependent behaviour is reproducible.
export PYTHONHASHSEED=0
export base_log_dir="./"
export base_dir="/system/user/publicdata"
# Model cache; also holds the *_r_16_pca.bin files passed as --lora_path.
export cache_dir="${base_dir}/llm"
export data_cache_dir="$base_dir"
export WANDB_PROJECT="EVA"
export WANDB_ENTITY="ml_eva"

# Positional arguments (see Usage above).
model_name="$1"
task_name="$2"
experiment_name="$3"
batch_size="$4"
epochs="$5"
seq_len="$7"
# Tasks listed here are swept with 3 seeds instead of 5.
high_resource=("qqp" "qnli" "mnli" "sst2")

#######################################
# Test whether a task name is exactly one of the entries in high_resource.
# Globals:   high_resource (read)
# Arguments: $1 - task name
# Returns:   0 if $1 is listed, 1 otherwise
# Note: an unanchored `[[ ${high_resource[@]} =~ $task ]]` would also accept
# substrings ("nli", "qq"), regex metacharacters, and the empty string.
#######################################
is_high_resource() {
  local candidate
  for candidate in "${high_resource[@]}"; do
    [[ "$candidate" == "$1" ]] && return 0
  done
  return 1
}

# Master port for torch.distributed, drawn from the private/dynamic range
# 49152-65535. Computed once and reused by every run of the sweep.
min_port=49152
max_port=65535
random_port=$(( min_port + RANDOM % (max_port - min_port + 1) ))

# Learning-rate grid depends on the backbone model.
if [[ "$model_name" == "roberta-large" ]]; then
  learning_rates=(1e-3 4e-4 1e-4)
else
  learning_rates=(4e-3 1e-3 4e-4)
fi

if is_high_resource "$task_name"; then
  seeds=(0 10 101)
else
  seeds=(0 10 101 1010 10101)
fi

#######################################
# Launch one training run for a single (seed, LoRA rank, learning rate)
# configuration, unless its output directory already exists.
# Globals:
#   base_log_dir, experiment_name, task_name, model_name, cache_dir,
#   data_cache_dir, num_gpus, random_port, seq_len, batch_size, epochs (read)
#   WANDB_DIR (written, exported)
# Arguments:
#   $1 - seed, $2 - LoRA rank (--lora_r), $3 - learning rate
# Returns:
#   0 if the run succeeded or was skipped, 1 if it failed
#######################################
run_config() {
  local seed=$1 r=$2 lr=$3
  local output_dir="${base_log_dir}/${experiment_name}/${task_name}/r_${r}_lr_${lr}/${seed}"

  # An existing directory means this configuration was already run (or
  # attempted); skipping it lets an interrupted sweep be resumed.
  if [[ -d "$output_dir" ]]; then
    return 0
  fi

  echo "Logging to ${output_dir}"
  if ! mkdir -p "$output_dir/log"; then
    printf 'error: cannot create %s\n' "$output_dir/log" >&2
    return 1
  fi
  export WANDB_DIR="$output_dir/log"

  # Arguments are held in an array so values containing spaces or glob
  # characters reach run_glue.py intact.
  local -a train_args=(
    --model_name_or_path "$model_name"
    --cache_dir "$cache_dir"
    --data_cache_dir "$data_cache_dir"
    --lora_path "${cache_dir}/${model_name}_${task_name}_r_16_pca.bin"
    --task_name "$task_name"
    --do_train
    --fp16
    --do_eval
    --max_seq_length "$seq_len"
    --per_device_train_batch_size "$batch_size"
    --learning_rate "$lr"
    --num_train_epochs "$epochs"
    --output_dir "$output_dir/model"
    --overwrite_output_dir
    --logging_steps 10
    --logging_dir "$output_dir/log"
    --evaluation_strategy epoch
    --save_strategy no
    --warmup_ratio 0.06
    --apply_lora
    --lora_r "$r"
    --lora_alpha 1
    --lora_kind LoRA
    --adaptive_ranks
    --redistribute_from_scratch
    --redist_metric raw
    --seed "$seed"
    --weight_decay 0.1
    --report_to=wandb
    "--experiment_name=${experiment_name}"
  )

  if ! python -m torch.distributed.launch \
      --nproc_per_node="$num_gpus" --master_port "$random_port" \
      examples/text-classification/run_glue.py "${train_args[@]}"; then
    # output_dir already exists now, so a plain re-run would skip this config.
    printf 'warning: run failed, remove %s to retry it\n' "$output_dir" >&2
    return 1
  fi
}

failed_runs=0
for seed in "${seeds[@]}"; do
  for r in 2 4 8 16; do  # LoRA ranks
    for lr in "${learning_rates[@]}"; do
      run_config "$seed" "$r" "$lr" || failed_runs=$(( failed_runs + 1 ))
    done
  done
done

# Report failures through the exit status instead of only reflecting the
# outcome of the last run.
if (( failed_runs > 0 )); then
  printf '%d run(s) failed\n' "$failed_runs" >&2
  exit 1
fi
