#!/bin/bash

# Driver for the GPT-2 Medium + LoRA experiment on WebNLG: for each seed it
# trains, generates with beam search, decodes, and evaluates.
#
# Usage:
#   <script> DROPOUT_PATTERN DROPOUT_RATE AUG_LOSS AUG_LOSS_WEIGHT GPU_ID
#
# Arguments:
#   $1  value passed to --modified_dropout_pattern
#   $2  value passed to --modified_dropout_rate
#   $3  value passed to --modified_aug_loss
#   $4  value passed to --modified_aug_loss_weight
#   $5  GPU index; exported as CUDA_VISIBLE_DEVICES and used to pick ports

if (($# < 5)); then
  printf 'Usage: %s DROPOUT_PATTERN DROPOUT_RATE AUG_LOSS AUG_LOSS_WEIGHT GPU_ID\n' \
    "${0##*/}" >&2
  exit 2
fi

modified_dropout_pattern=$1
modified_dropout_rate=$2
modified_aug_loss=$3
modified_aug_loss_weight=$4
GPU_ID=$5

# GPU_ID is used as an array index in arithmetic later on, so it has to be a
# plain non-negative integer.
if [[ ! "${GPU_ID}" =~ ^[0-9]+$ ]]; then
  printf 'error: GPU_ID must be a non-negative integer, got "%s"\n' "${GPU_ID}" >&2
  exit 2
fi

# Placeholders: replace with the repository root and the Python interpreter
# to use. They are quoted so the file still parses while they are unfilled
# (unquoted, the angle brackets are read as redirections).
project_root="<project_root>"
python_path="<python_path>"

cd "${project_root}/exps/NLG" || exit
# Resulting search order: exps/NLG first, then the project root, then
# whatever PYTHONPATH already contained.
export PYTHONPATH="${project_root}/exps/NLG:${project_root}:${PYTHONPATH}"

# Master ports for torch.distributed.launch. The script looks up index
# GPU_ID * 6 + seed * 2 + {0: training, 1: generation}, so each row below
# belongs to one GPU_ID (0-3) and holds two ports for each of three seeds.
ports=(
  11176 12345 13035 13502 14554 15173
  15253 15445 16683 17893 18075 18525
  19211 20294 21202 21975 22304 22948
  24202 24570 24978 26423 28321 29763
) # count: 24

seed_min=0
seed_max=2

# Training step whose artifacts are evaluated: model.<step>.pt is loaded for
# generation and predictions are written to predict.<step>.jsonl.
eval_step=11270

# Number of leading test examples that go to the "s" split; the remainder
# goes to the "u" split. Labels below are the ones written to the log.
seen_count=2495
splits=(a s u)
split_labels=("All" "Seen" "Unseen")

# Append a timestamped message to the current run's log file.
# Globals:   log_pth (read)
# Arguments: $* - message text
log_event() {
  echo "$(date "+%Y%m%d-%H%M%S"): $*" >>"${log_pth}"
}

for ((seed = seed_min; seed <= seed_max; seed++)); do

  cd "${project_root}/exps/NLG" || exit

  # Two distinct master ports per (GPU, seed): one for training and one for
  # generation. An index past the end of the table expands to an empty
  # string, which would make --master_port swallow the script path, so stop
  # with a clear message instead.
  train_port=${ports[$((GPU_ID * 6 + seed * 2 + 0))]}
  beam_port=${ports[$((GPU_ID * 6 + seed * 2 + 1))]}
  if [[ -z "${train_port}" || -z "${beam_port}" ]]; then
    printf 'error: no port configured for GPU_ID=%s seed=%s\n' "${GPU_ID}" "${seed}" >&2
    exit 1
  fi

  TIME=$(date "+%Y%m%d-%H%M%S")
  run_name="GPT2_M_webnlg_${TIME}_sd_${seed}_GPU_${GPU_ID}_dp_${modified_dropout_pattern}_dr_${modified_dropout_rate}_${modified_aug_loss}_${modified_aug_loss_weight}"
  runs+=("${run_name}")
  logging_dir=./logs/${run_name}
  mkdir -p "${logging_dir}"
  log_pth=${logging_dir}/log.txt
  touch "${log_pth}"
  output_dir=./checkpoints/${run_name}
  mkdir -p "${output_dir}"

  # 1. Train GPT-2 Medium with LoRA (see our paper for hyperparameters for GPT-2 Medium)
  log_event "Start training GPT-2 Medium with LoRA on WebNLG dataset..."
  CUDA_VISIBLE_DEVICES=${GPU_ID} \
    "${python_path}" -m torch.distributed.launch --nproc_per_node=1 --master_port "${train_port}" src/gpt2_ft.py \
    --random_seed "${seed}" \
    --modified_aug_loss="${modified_aug_loss}" \
    --modified_aug_loss_weight="${modified_aug_loss_weight}" \
    --modified_dropout_pattern "${modified_dropout_pattern}" \
    --modified_dropout_rate "${modified_dropout_rate}" \
    --modified_disable_ori_dropout \
    --train_data ./data/webnlg_challenge_2017/train.jsonl \
    --valid_data ./data/webnlg_challenge_2017/valid.jsonl \
    --train_batch_size 8 \
    --grad_acc 1 \
    --valid_batch_size 4 \
    --seq_len 512 \
    --model_card gpt2.md \
    --init_checkpoint ./pretrained_checkpoints/gpt2-medium-pytorch_model.bin \
    --platform local \
    --clip 0.0 \
    --lr 0.0002 \
    --weight_decay 0.01 \
    --correct_bias \
    --adam_beta2 0.999 \
    --scheduler linear \
    --warmup_step 500 \
    --max_epoch 5 \
    --save_interval 2000 \
    --eval_interval 2000 \
    --lora_dim 4 \
    --lora_alpha 32 \
    --lora_dropout 0.1 \
    --label_smooth 0.1 \
    --work_dir "${output_dir}" |
    tee -a "${log_pth}"

  log_event "Finish training GPT-2 Medium with LoRA on WebNLG dataset..."

  # Steps 2-4 run in a subshell and are best-effort: a failure here must not
  # stop the remaining seeds, hence the trailing `|| true`.
  (
    # The training pipeline's exit status is hidden by `tee`, so check for the
    # checkpoint itself before spending time on generation and evaluation.
    checkpoint=${output_dir}/model.${eval_step}.pt
    if [[ ! -f "${checkpoint}" ]]; then
      log_event "Checkpoint ${checkpoint} not found; skipping generation, decoding and evaluation."
      printf 'error: checkpoint %s not found; skipping evaluation for seed %s\n' \
        "${checkpoint}" "${seed}" >&2
      exit 1
    fi

    # 2. Generate outputs from the trained model using beam search:
    output_idx_file=${logging_dir}/predict.${eval_step}.jsonl

    log_event "Start generating outputs from the trained model using beam search..."
    # NOTE(review): the ./../../ prefix on --output_file looks like it undoes
    # gpt2_beam.py resolving the path relative to --work_dir, so the file
    # lands in ${logging_dir} -- confirm against gpt2_beam.py.
    CUDA_VISIBLE_DEVICES=${GPU_ID} \
      "${python_path}" -m torch.distributed.launch --nproc_per_node=1 --master_port "${beam_port}" src/gpt2_beam.py \
      --data ./data/webnlg_challenge_2017/test.jsonl \
      --batch_size 1 \
      --seq_len 512 \
      --eval_len 64 \
      --model_card gpt2.md \
      --init_checkpoint "${checkpoint}" \
      --platform local \
      --lora_dim 4 \
      --lora_alpha 32 \
      --beam 10 \
      --length_penalty 0.8 \
      --no_repeat_ngram_size 4 \
      --repetition_penalty 1.0 \
      --eos_token_id 628 \
      --work_dir "${output_dir}" \
      --output_file "./../../${output_idx_file}" |
      tee -a "${log_pth}"
    log_event "Finish generating outputs from the trained model using beam search..."

    # 3. Decode outputs from step (2)
    # Each split gets its own directory holding its slice of the references
    # (test_formatted.jsonl) and of the predictions (predict.<step>.jsonl).
    test_formatted=./data/webnlg_challenge_2017/test_formatted.jsonl
    for split in "${splits[@]}"; do
      mkdir -p "${logging_dir}/${split}/"
    done

    cat "${test_formatted}" >"${logging_dir}/a/test_formatted.jsonl"
    head -n "${seen_count}" "${test_formatted}" >"${logging_dir}/s/test_formatted.jsonl"
    tail -n "+$((seen_count + 1))" "${test_formatted}" >"${logging_dir}/u/test_formatted.jsonl"

    # Predictions are routed by the number that follows the first ": " on each
    # line: values below seen_count go to "s", the rest go to "u".
    cat "${output_idx_file}" >"${logging_dir}/a/predict.${eval_step}.jsonl"
    awk -F ": " -v n="${seen_count}" '$2+0 < n' "${output_idx_file}" \
      >"${logging_dir}/s/predict.${eval_step}.jsonl"
    # minus_id.py is given the "u" file and the offset; presumably it shifts
    # the ids so they start at 0 again -- confirm against src/minus_id.py.
    # It runs with the same interpreter as every other Python step.
    awk -F ": " -v n="${seen_count}" '$2+0 >= n' "${output_idx_file}" \
      >"${logging_dir}/u/predict.${eval_step}.jsonl" &&
      "${python_path}" src/minus_id.py --file "${logging_dir}/u/predict.${eval_step}.jsonl" --value "${seen_count}"

    log_event "Start decoding outputs from step (2)..."
    for split in "${splits[@]}"; do
      split_dir=${logging_dir}/${split}
      # --output_ref_file keeps its trailing slash, exactly as it was passed
      # before, in case gpt2_decode.py builds paths by string concatenation.
      CUDA_VISIBLE_DEVICES=${GPU_ID} \
        "${python_path}" src/gpt2_decode.py \
        --vocab ./vocab \
        --sample_file "${split_dir}/predict.${eval_step}.jsonl" \
        --input_file "${split_dir}/test_formatted.jsonl" \
        --ref_type webnlg \
        --ref_num 6 \
        --output_ref_file "${split_dir}/" \
        --output_pred_file "${split_dir}/webnlg_pred.txt" \
        --tokenize --lower |
        tee -a "${log_pth}"
    done
    log_event "Finish decoding outputs from step (2)..."

    # 4. Run evaluation on WebNLG test set
    log_event "Start running evaluation on WebNLG test set..."

    for i in "${!splits[@]}"; do
      split_dir=${logging_dir}/${splits[i]}
      echo "${split_labels[i]} test samples..." >>"${log_pth}"
      # eval.py is handed absolute paths.
      output_ref_file=$(realpath "${split_dir}/")
      output_pred_file=$(realpath "${split_dir}/webnlg_pred.txt")
      CUDA_VISIBLE_DEVICES=${GPU_ID} "${python_path}" ./eval/GenerationEval/eval.py \
        -R "${output_ref_file}/reference" \
        -H "${output_pred_file}" \
        -nr 6 \
        -m bleu,meteor,ter |
        tee -a "${log_pth}"
    done

    log_event "Finish running evaluation on WebNLG test set..."

  ) || true

done
