#!/bin/bash
# Environment shared by this script and every process it starts.
# cuBLAS workspace configuration, see
# https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
CUBLAS_WORKSPACE_CONFIG=":16:8"
PYTHONHASHSEED=0
# Root directory under which cached files and run outputs are addressed.
base_dir="/system/user/publicwork/hauzenbe/eva_cache"
# Weights & Biases project / entity.
WANDB_PROJECT="EVA"
WANDB_ENTITY="ml_eva"
# Only the first GPU is made visible to child processes.
CUDA_VISIBLE_DEVICES="0"
export CUBLAS_WORKSPACE_CONFIG PYTHONHASHSEED base_dir \
  WANDB_PROJECT WANDB_ENTITY CUDA_VISIBLE_DEVICES

# private port range
min_port=49152 # >= 49152
max_port=65535 # <= 65535

# Number of runs started concurrently (first positional argument).
# Defaults to 1 when the argument is omitted. Anything that is not a positive
# integer is rejected: an empty, zero or non-numeric value would never match
# the chunk-size comparison below, so every combination would end up in one
# single chunk and be launched at the same time.
proc_per_gpu=${1:-1}
if ! [[ "$proc_per_gpu" =~ ^[1-9][0-9]*$ ]]; then
  printf 'usage: %s [proc_per_gpu]\n' "${0##*/}" >&2
  printf 'error: proc_per_gpu must be a positive integer, got "%s"\n' \
    "$proc_per_gpu" >&2
  exit 2
fi

model_card='gpt2.md'
model_name='gpt2-medium'
task_name='e2e_nlg'
experiment_name='pca_init_adaptive'
batch_size=8
epochs=5
seq_len=512
# array params (the sweep is the cartesian product of these three arrays)
seeds=(0 10 101) # 0 10 101
lora_dims=(16 8 4 2) # 16 8 4 2
learning_rates=(1e-3 4e-4 2e-4) # 1e-3 4e-4 2e-4

# model_checkpoint
model_checkpoint="$model_name-pytorch_model.bin"

# create chunks from array params
# Every combination is encoded as "seed|lora_dim|lr". Up to proc_per_gpu
# combinations are joined with "%" into one element of `chunks`.
chunks=()
chunk=()
for a in "${seeds[@]}"; do
  for b in "${lora_dims[@]}"; do
    for c in "${learning_rates[@]}"; do
      chunk+=("$a|$b|$c")
      # add chunk to chunks once it holds proc_per_gpu combinations
      if (( ${#chunk[@]} == proc_per_gpu )); then
        # "${chunk[*]}" joins on the first character of IFS; the subshell
        # keeps the IFS change local.
        chunk_string=$(IFS='%'; printf '%s' "${chunk[*]}")
        chunks+=("$chunk_string")
        chunk=()
      fi
    done
  done
done

# Add the final chunk (if not empty)
if (( ${#chunk[@]} > 0 )); then
  chunk_string=$(IFS='%'; printf '%s' "${chunk[*]}")
  chunks+=("$chunk_string")
fi


# start scripts in parallel
#
# Every element of `chunks` is a "%"-joined list of "seed|lora_dim|lr"
# entries. The runs of one chunk are started in the background together and
# the loop waits for all of them before it moves on to the next chunk.
for chunk_spec in "${chunks[@]}"; do
  # Split the chunk back into its "seed|lora_dim|lr" entries.
  IFS='%' read -r -a parallel_combs <<< "$chunk_spec"
  n_jobs=${#parallel_combs[@]}

  # find free ports
  # One port per run in this chunk. The search is sized by the number of runs
  # in the chunk (the last chunk may be shorter), so it always stops as soon
  # as enough ports were found.
  unused_ports=()
  for (( port = min_port; port <= max_port; port++ )); do
    if ! lsof -Pi ":$port" -sTCP:LISTEN -t >/dev/null; then
      unused_ports+=("$port")  # Add to array if port is unused
      if (( ${#unused_ports[@]} >= n_jobs )); then
        break  # Exit the loop once we've found enough ports
      fi
    fi
  done

  # Abort before any output directory is created: a run started without a
  # port would fail and leave its directory behind, and existing directories
  # are skipped on later invocations.
  if (( ${#unused_ports[@]} < n_jobs )); then
    printf 'error: found only %d free port(s) in %s-%s, need %d\n' \
      "${#unused_ports[@]}" "$min_port" "$max_port" "$n_jobs" >&2
    exit 1
  fi

  job_pids=()
  job_dirs=()
  for (( i = 0; i < n_jobs; i++ )); do
    IFS='|' read -r seed lora_dim lr <<< "${parallel_combs[i]}"

    dir="r_${lora_dim}_lr_${lr}"
    output_dir="$base_dir/$model_card/$task_name/$experiment_name/$dir/$seed"

    # An existing output directory marks the run as already started/done.
    if [[ ! -d "$output_dir" ]]; then
      echo "Logging to ${output_dir}"
      mkdir -p "$output_dir/log"
      # Exported before the launch so the background process inherits it.
      export WANDB_DIR="$output_dir/log"

      python -m torch.distributed.launch --nproc_per_node=1 --master_port "${unused_ports[i]}" \
      src/gpt2_ft.py \
      --platform local \
      --train_data ./data/e2e/train.jsonl \
      --valid_data ./data/e2e/valid.jsonl \
      --train_batch_size "$batch_size" \
      --grad_acc 1 \
      --valid_batch_size "$batch_size" \
      --seq_len "$seq_len" \
      --model_card "$model_card" \
      --init_checkpoint "./pretrained_checkpoints/$model_checkpoint" \
      --clip 0.0 \
      --lr "$lr" \
      --weight_decay 0.01 \
      --correct_bias \
      --adam_beta2 0.999 \
      --scheduler linear \
      --warmup_step 500 \
      --max_epoch "$epochs" \
      --eval_interval 1000 \
      --lora_dim "$lora_dim" \
      --lora_alpha 32 \
      --lora_dropout 0.0 \
      --label_smooth 0.1 \
      --work_dir "$output_dir" \
      --random_seed "$seed" \
      --lora_path "$base_dir/${model_name}_e2e_nlg_r_32_pca.bin" \
      --adaptive_ranks &
      job_pids+=("$!")
      job_dirs+=("$output_dir")
    fi
  done

  # Wait for every run of this chunk and report the ones that failed; the
  # sweep itself continues with the next chunk.
  for (( j = 0; j < ${#job_pids[@]}; j++ )); do
    if ! wait "${job_pids[j]}"; then
      printf 'warning: run in %s exited with a non-zero status; remove the directory to re-run it\n' \
        "${job_dirs[j]}" >&2
    fi
  done
done