#!/bin/bash
# Beam-search decoding sweep over the LoRA runs of one experiment.
#
# Usage: <this script> PROC_PER_GPU
#   PROC_PER_GPU  how many src/gpt2_beam.py processes run at the same time on
#                 the GPU selected by CUDA_VISIBLE_DEVICES (positive integer).
#
# Every (seed, lora_dim, learning_rate) combination listed below is decoded
# from the best_model.pt checkpoint under
#   $base_dir/$model_card/$task_name/$experiment_name/r_<dim>_lr_<lr>/<seed>/
# in groups of PROC_PER_GPU concurrent jobs.
export CUBLAS_WORKSPACE_CONFIG=":16:8" # https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
export PYTHONHASHSEED=0
export base_dir="/system/user/publicwork/hauzenbe/eva_cache"
export WANDB_PROJECT="EVA"
export WANDB_ENTITY="ml_eva"
export CUDA_VISIBLE_DEVICES="0"

# private port range
min_port=49152 # >= 49152
max_port=65535 # <= 65535

# Number of concurrent jobs. Validate it up front: an empty or non-numeric
# value is evaluated as 0 by the numeric comparisons below, which silently
# breaks the chunking and the free-port search.
proc_per_gpu=$1
if [[ ! $proc_per_gpu =~ ^[1-9][0-9]*$ ]]; then
  printf 'usage: %s PROC_PER_GPU  (positive integer; got %q)\n' \
    "${0##*/}" "$proc_per_gpu" >&2
  exit 2
fi
model_card='gpt2.md'
task_name='e2e_nlg'
experiment_name='rand_init'
epochs=5 # NOTE(review): not referenced anywhere in this script
seq_len=512
beam=10
batch_size=1
# array params
seeds=(0 10 101) # 0 10 101
lora_dims=(16 8 4 2) # 16 8 4 2
learning_rates=(1e-3 4e-4 2e-4) # 1e-3 4e-4 2e-4

# create chunks from array params

#######################################
# Group items into '%'-joined chunks of at most SIZE items each.
# Arguments: $1     - chunk size
#            $2..   - items (must not contain '%' or newlines)
# Outputs:   one chunk per line on stdout, in input order; the last chunk
#            holds the remainder and may be shorter than SIZE.
#######################################
chunk_items() {
  local size=$1
  shift
  local IFS='%'  # makes "${group[*]}" join with '%'
  local -a group=()
  local item
  for item in "$@"; do
    group+=("$item")
    if (( ${#group[@]} == size )); then
      printf '%s\n' "${group[*]}"
      group=()
    fi
  done
  # Flush the final, possibly partial, chunk.
  if (( ${#group[@]} > 0 )); then
    printf '%s\n' "${group[*]}"
  fi
}

# Cartesian product of the sweep parameters, each encoded as "seed|lora_dim|lr".
combos=()
for s in "${seeds[@]}"; do
  for r in "${lora_dims[@]}"; do
    for l in "${learning_rates[@]}"; do
      combos+=("${s}|${r}|${l}")
    done
  done
done

# Each chunk holds up to proc_per_gpu combinations that will run concurrently.
mapfile -t chunks < <(chunk_items "$proc_per_gpu" "${combos[@]}")

echo "${chunks[@]}"

# start scripts in parallel
# Each chunk's jobs run concurrently in the background; the next chunk starts
# only after every job of the current one has exited.

#######################################
# Print up to COUNT TCP ports in [LO, HI] that currently have no listener.
# Arguments: $1 - number of ports wanted
#            $2 - lowest port to try
#            $3 - highest port to try
# Outputs:   one port per line, ascending; fewer than COUNT lines if the
#            range is exhausted.
#######################################
find_free_ports() {
  local count=$1 lo=$2 hi=$3
  local port found=0
  (( count > 0 )) || return 0
  for (( port = lo; port <= hi; port++ )); do
    # lsof exits non-zero when no process is LISTENing on the port.
    if ! lsof -Pi ":${port}" -sTCP:LISTEN -t >/dev/null; then
      printf '%s\n' "$port"
      found=$(( found + 1 ))
      if (( found >= count )); then
        break
      fi
    fi
  done
  return 0
}

# Without lsof every port would look free, defeating the check above.
if (( ${#chunks[@]} > 0 )) && ! command -v lsof >/dev/null; then
  echo "error: lsof is required to find free master ports" >&2
  exit 1
fi

n_jobs=0
n_failed=0
for chunk in "${chunks[@]}"; do
  IFS='%' read -ra parallel_combs <<< "$chunk"

  # One free rendezvous port per concurrent job. NOTE: a port can still be
  # taken by another process between this check and the launch below.
  mapfile -t unused_ports < <(find_free_ports "${#parallel_combs[@]}" "$min_port" "$max_port")
  if (( ${#unused_ports[@]} < ${#parallel_combs[@]} )); then
    printf 'error: only %d free port(s) in %s-%s, need %d\n' \
      "${#unused_ports[@]}" "$min_port" "$max_port" "${#parallel_combs[@]}" >&2
    exit 1
  fi

  pids=()
  for i in "${!parallel_combs[@]}"; do
    IFS='|' read -r seed lora_dim lr <<< "${parallel_combs[i]}"

    dir="r_${lora_dim}_lr_${lr}"
    output_dir="$base_dir/$model_card/$task_name/$experiment_name/$dir/$seed"

    python -m torch.distributed.launch --nproc_per_node=1 \
      --master_port "${unused_ports[i]}" \
      src/gpt2_beam.py \
      --platform local \
      --data ./data/e2e/test.jsonl \
      --batch_size "$batch_size" \
      --seq_len "$seq_len" \
      --eval_len 64 \
      --model_card "$model_card" \
      --init_checkpoint "$output_dir/best_model.pt" \
      --lora_dim "$lora_dim" \
      --lora_alpha 1 \
      --beam "$beam" \
      --length_penalty 0.9 \
      --no_repeat_ngram_size 4 \
      --repetition_penalty 1.0 \
      --eos_token_id 628 \
      --work_dir "$output_dir" \
      --random_seed "$seed" \
      --output_file beam_prediction.jsonl &
    pids+=("$!")
  done

  # Wait on each PID individually; a bare `wait` would discard exit statuses.
  # pids[j] was launched for parallel_combs[j].
  for j in "${!pids[@]}"; do
    n_jobs=$(( n_jobs + 1 ))
    if ! wait "${pids[j]}"; then
      n_failed=$(( n_failed + 1 ))
      printf 'warning: job %s (seed|lora_dim|lr) failed\n' "${parallel_combs[j]}" >&2
    fi
  done
done

if (( n_failed > 0 )); then
  printf 'error: %d of %d evaluation job(s) failed\n' "$n_failed" "$n_jobs" >&2
  exit 1
fi