#!/bin/bash
#
# Fine-tune GPT-2 Medium with LoRA on the E2E dataset for every seed in
# seed_min..seed_max and, for each seed, run beam-search generation,
# decoding and the E2E scoring script on the resulting checkpoint.
#
# Usage:
#   <this script> DROPOUT_PATTERN DROPOUT_RATE AUG_LOSS AUG_LOSS_WEIGHT GPU_ID
#
# Arguments:
#   $1  value passed to gpt2_ft.py as --modified_dropout_pattern
#   $2  value passed to gpt2_ft.py as --modified_dropout_rate
#   $3  value passed to gpt2_ft.py as --modified_aug_loss
#   $4  value passed to gpt2_ft.py as --modified_aug_loss_weight
#   $5  GPU index; set as CUDA_VISIBLE_DEVICES for every Python step and
#       used to select master ports from the `ports` table
#
# Outputs (relative to ${project_root}/exps/NLG):
#   ./logs/<run_name>/         log.txt plus generation / decoding outputs
#   ./checkpoints/<run_name>/  passed as --work_dir to training / generation
#
# A failure for one seed does not stop the remaining seeds.

if (($# < 5)); then
  printf 'Usage: %s DROPOUT_PATTERN DROPOUT_RATE AUG_LOSS AUG_LOSS_WEIGHT GPU_ID\n' \
    "${0##*/}" >&2
  exit 2
fi

modified_dropout_pattern=$1
modified_dropout_rate=$2
modified_aug_loss=$3
modified_aug_loss_weight=$4
GPU_ID=$5
# Placeholders: set these to the repository root and the Python interpreter.
# They are quoted so that the literal '<' / '>' are not parsed as redirections.
project_root="<project_root>"
python_path="<python_path>"

# GPU_ID is used in arithmetic below, so it must be a plain decimal integer
# (no sign, no leading zeros that bash would read as octal).
gpu_id_re='^(0|[1-9][0-9]*)$'
if [[ ! ${GPU_ID} =~ ${gpu_id_re} ]]; then
  printf 'error: GPU_ID must be a non-negative integer, got "%s"\n' "${GPU_ID}" >&2
  exit 2
fi

cd "${project_root}/exps/NLG" || exit
# Resulting search order: exps/NLG first, then the project root, then any
# pre-existing PYTHONPATH.
export PYTHONPATH="${project_root}/exps/NLG:${project_root}:${PYTHONPATH}"

ports=(11176 12345 13035 13502 14554 15173 15253 15445 16683 17893 18075 18525 19211 20294 21202 21975 22304 22948 24202 24570 24978 26423 28321 29763) # count: 24

seed_min=0
seed_max=2

# Port slot layout: index = GPU_ID * 6 + seed * 2 + step, where step 0 is
# training and step 1 is generation. Verify the highest slot this invocation
# will need before any training starts, instead of passing an empty
# --master_port later on.
max_port_idx=$((GPU_ID * 6 + seed_max * 2 + 1))
if ((max_port_idx < 0 || max_port_idx >= ${#ports[@]})); then
  printf 'error: GPU_ID=%s needs port slot %d but only %d ports are configured\n' \
    "${GPU_ID}" "${max_port_idx}" "${#ports[@]}" >&2
  exit 2
fi

# Append a timestamped message to the log file of the current run.
# Globals:   log_pth (read)
# Arguments: message words
log_event() {
  echo "$(date "+%Y%m%d-%H%M%S"): $*" >>"${log_pth}"
}

for ((seed = seed_min; seed <= seed_max; seed++)); do
  TIME=$(date "+%Y%m%d-%H%M%S")
  run_name="GPT2_M_e2e_${TIME}_sd_${seed}_GPU_${GPU_ID}_dp_${modified_dropout_pattern}_dr_${modified_dropout_rate}_${modified_aug_loss}_${modified_aug_loss_weight}"
  # NOTE(review): `runs` is only appended to in the visible part of the file;
  # kept in case something that sources or extends this script reads it.
  runs+=("${run_name}")
  logging_dir=./logs/${run_name}
  log_pth=${logging_dir}/log.txt
  output_dir=./checkpoints/${run_name}
  mkdir -p "${logging_dir}"
  touch "${log_pth}"
  mkdir -p "${output_dir}"

  # 1. Train GPT-2 Medium with LoRA (see our paper for hyperparameters for GPT-2 Medium)
  log_event "Start training GPT-2 Medium with LoRA on E2E dataset..."
  port=${ports[$((GPU_ID * 6 + seed * 2 + 0))]}
  CUDA_VISIBLE_DEVICES="${GPU_ID}" \
    "${python_path}" -m torch.distributed.launch --nproc_per_node=1 --master_port "${port}" src/gpt2_ft.py \
    --random_seed "${seed}" \
    --modified_aug_loss="${modified_aug_loss}" \
    --modified_aug_loss_weight="${modified_aug_loss_weight}" \
    --modified_dropout_pattern "${modified_dropout_pattern}" \
    --modified_dropout_rate "${modified_dropout_rate}" \
    --modified_disable_ori_dropout \
    --train_data ./data/e2e/train.jsonl \
    --valid_data ./data/e2e/valid.jsonl \
    --train_batch_size 8 \
    --grad_acc 1 \
    --valid_batch_size 4 \
    --seq_len 512 \
    --model_card gpt2.md \
    --init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
    --platform local \
    --clip 0.0 \
    --lr 0.0002 \
    --weight_decay 0.01 \
    --correct_bias \
    --adam_beta2 0.999 \
    --scheduler linear \
    --warmup_step 500 \
    --max_epoch 5 \
    --save_interval 2000 \
    --eval_interval 2000 \
    --lora_dim 4 \
    --lora_alpha 32 \
    --lora_dropout 0.0 \
    --label_smooth 0.1 \
    --work_dir "${output_dir}" |
    tee -a "${log_pth}"
  # The pipeline's own status is tee's; PIPESTATUS[0] is the trainer's.
  train_status=${PIPESTATUS[0]}
  log_event "Finish training GPT-2 Medium with LoRA on E2E dataset..."
  if ((train_status != 0)); then
    log_event "Training exited with status ${train_status}"
  fi

  # Every later step depends on this checkpoint; without it generation cannot
  # load a model, so skip straight to the next seed.
  checkpoint=${output_dir}/model.26290.pt
  if [[ ! -e ${checkpoint} ]]; then
    log_event "Checkpoint ${checkpoint} not found; skipping generation and evaluation for seed ${seed}"
    printf 'warning: %s not found; skipping evaluation for seed %s\n' \
      "${checkpoint}" "${seed}" >&2
    continue
  fi

  # Steps 2-4 run in a subshell so that `exit` aborts only this seed's
  # evaluation chain. Each step's status is read from PIPESTATUS[0] because
  # `set -e` has no effect inside a command that is the left side of `||`.
  (
    # 2. Generate outputs from the trained model using beam search:
    output_idx_file=${logging_dir}/predict.26290.jsonl

    log_event "Start generating outputs from the trained model using beam search..."
    port=${ports[$((GPU_ID * 6 + seed * 2 + 1))]}
    # NOTE(review): the ./../../ prefix suggests gpt2_beam.py resolves
    # --output_file relative to --work_dir (./checkpoints/<run_name>) - confirm.
    CUDA_VISIBLE_DEVICES="${GPU_ID}" \
      "${python_path}" -m torch.distributed.launch --nproc_per_node=1 --master_port "${port}" src/gpt2_beam.py \
      --data ./data/e2e/test.jsonl \
      --batch_size 1 \
      --seq_len 512 \
      --eval_len 64 \
      --model_card gpt2.md \
      --init_checkpoint "${checkpoint}" \
      --platform local \
      --lora_dim 4 \
      --lora_alpha 32 \
      --beam 10 \
      --length_penalty 0.9 \
      --no_repeat_ngram_size 4 \
      --repetition_penalty 1.0 \
      --eos_token_id 628 \
      --work_dir "${output_dir}" \
      --output_file "./../../${output_idx_file}" |
      tee -a "${log_pth}"
    step_status=${PIPESTATUS[0]}
    ((step_status == 0)) || exit "${step_status}"
    log_event "Finish generating outputs from the trained model using beam search..."

    # 3. Decode outputs from step (2)
    output_ref_file=${logging_dir}/e2e_ref.txt
    output_pred_file=${logging_dir}/e2e_pred.txt

    log_event "Start decoding outputs from step (2)..."
    CUDA_VISIBLE_DEVICES="${GPU_ID}" \
      "${python_path}" src/gpt2_decode.py \
      --vocab ./vocab \
      --sample_file "${output_idx_file}" \
      --input_file ./data/e2e/test_formatted.jsonl \
      --output_ref_file "${output_ref_file}" \
      --output_pred_file "${output_pred_file}" |
      tee -a "${log_pth}"
    step_status=${PIPESTATUS[0]}
    ((step_status == 0)) || exit "${step_status}"
    log_event "Finish decoding outputs from step (2)..."

    # 4. Run evaluation on E2E test set
    log_event "Start running evaluation on E2E test set..."
    CUDA_VISIBLE_DEVICES="${GPU_ID}" \
      "${python_path}" eval/e2e/measure_scores.py \
      "${output_ref_file}" \
      "${output_pred_file}" \
      -p |
      tee -a "${log_pth}"
    step_status=${PIPESTATUS[0]}
    ((step_status == 0)) || exit "${step_status}"
    log_event "Finish running evaluation on E2E test set..."

  ) || log_event "Evaluation pipeline failed for seed ${seed}; continuing with the next seed"

done
