#!/bin/bash

# First pipeline stage to run. Every stage whose number is >= this value is
# executed: 0 = data preparation/binarization, 1 = training, 2 = decoding.
stage="0"

# Relaxed-attention settings forwarded to fairseq-train:
#   relaxAttn      -> --relaxed-attention-weight
#   relaxSelfAttn  -> --relaxed-self-attention-weight
# The two boolean switches (true/false) add --relaxation-matched-inference
# and --attention-sigmoid-smoothing to the training command when enabled.
relaxAttn="0.2"
relaxSelfAttn="0.01"
relaxation_matched_inference="false"
attention_sigmoid_smoothing="false"

# Locations, relative to the directory the script is started from:
#   DATA   - prefix directory of the tokenized train/valid/test text
#   BIN    - destination of the binarized dataset
#   RESULT - checkpoint / output directory of the training run
DATA="examples/translation/iwslt14.tokenized.de-en"
BIN="data-bin/iwslt14.tokenized.de-en"
RESULT="results/transformer_small_iwslt14_de2en_cutoff"

#######################################
# Stage 0 worker: download/prepare the IWSLT14 data and binarize it.
# Globals:   DATA (read) - prefix directory of the tokenized text
#            BIN  (read) - destination directory for the binarized data
# Arguments: none
# Outputs:   whatever prepare-iwslt14.sh and preprocess.py print
# Returns:   1 if the preparation step fails (including a missing
#            examples/translation/ directory), otherwise the exit status
#            of the binarization command.
#######################################
prepare_and_binarize() {
  # Download and prepare the data.
  # The directory change happens inside a subshell so the caller's working
  # directory can never be altered, and a failed `cd` stops the stage
  # instead of running the following commands from the wrong place.
  (
    cd examples/translation/ || exit 1
    bash prepare-iwslt14.sh
  ) || {
    echo "Stage 0: data preparation in examples/translation/ failed" >&2
    return 1
  }

  # Binarize the dataset
  python fairseq_cli/preprocess.py \
    --source-lang de --target-lang en \
    --trainpref "${DATA}/train" \
    --validpref "${DATA}/valid" \
    --testpref "${DATA}/test" \
    --destdir "${BIN}" --thresholdtgt 0 --thresholdsrc 0 \
    --workers 20 --joined-dictionary
}

if [ "${stage}" -le 0 ]; then
  echo "Stage 0: Preprocess/binarize the data..."
  # Later stages need the binarized data, so stop here on failure.
  prepare_and_binarize || exit 1
fi

#######################################
# Stage 1 worker: train the transformer_iwslt_de_en model with cutoff
# augmentation and relaxed attention.
# Globals:   BIN (read)    - binarized dataset directory
#            RESULT (read) - checkpoint directory (created if missing)
#            relaxAttn, relaxSelfAttn (read) - relaxed-attention weights
#            relaxation_matched_inference, attention_sigmoid_smoothing
#              (read) - "true" adds the corresponding fairseq-train flag
# Arguments: none
# Outputs:   fairseq-train output
# Returns:   1 if RESULT cannot be created, otherwise the exit status of
#            fairseq-train.
#######################################
train_model() {
  # Optional flags are collected in a local array that always starts empty,
  # so a stale or inherited `opts` variable cannot leak into the command.
  local -a extra_opts=()

  # The switches are compared as strings rather than executed as commands.
  if [[ "${relaxation_matched_inference}" == true ]]; then
    extra_opts+=(--relaxation-matched-inference)
  fi

  if [[ "${attention_sigmoid_smoothing}" == true ]]; then
    extra_opts+=(--attention-sigmoid-smoothing)
  fi

  # Train the model w/ cutoff
  mkdir -p "${RESULT}" || return 1
  CUDA_VISIBLE_DEVICES=0 fairseq-train "${BIN}" \
    --arch transformer_iwslt_de_en \
    --augmentation \
    --augmentation_schema cut_off \
    --augmentation_masking_schema word \
    --augmentation_masking_probability 0.05 \
    --augmentation_replacing_schema mask \
    --share-all-embeddings \
    --optimizer adam \
    --adam-betas '(0.9, 0.98)' \
    --adam-eps 1e-9 \
    --clip-norm 0.0 \
    --weight-decay 0.0001 \
    --criterion label_smoothed_cross_entropy_with_regularization \
    --regularization_weight 5.0 \
    --label-smoothing 0.1 \
    --max-tokens 2048 \
    --dropout 0.3 \
    --attention-dropout 0.1 \
    --activation-dropout 0.1 \
    --lr-scheduler inverse_sqrt \
    --lr 7e-4 \
    --warmup-updates 6000 \
    --max-epoch 100 \
    --update-freq 1 \
    --ddp-backend=c10d \
    --keep-last-epochs 20 \
    --log-format tqdm \
    --log-interval 100 \
    --save-dir "${RESULT}" \
    --seed 1 \
    --relaxed-self-attention-weight "${relaxSelfAttn}" \
    --relaxed-attention-weight "${relaxAttn}" \
    "${extra_opts[@]}"
}

if [ "${stage}" -le 1 ]; then
  echo "Stage 1: Model Training"
  # Decoding needs the checkpoints written here, so stop on failure.
  train_model || exit 1
fi

#######################################
# Stage 2 worker: average the last 5 epoch checkpoints, decode with the
# averaged model and score the output.
# Globals:   RESULT (read) - run directory holding the checkpoints; the
#                            averaged checkpoint and all decoding outputs
#                            are written here as well
#            BIN (read)    - binarized dataset directory
# Arguments: none
# Outputs:   output of the averaging, splitting and scoring tools; the
#            fairseq-generate output goes to
#            ${RESULT}/checkpoint_last5.gen
# Returns:   1 if averaging or generation fails, otherwise the exit status
#            of the final scoring command.
#######################################
decode_and_score() {
  # RESULT is deliberately not reassigned here: decoding uses the run
  # directory configured for the script, i.e. the one the training stage
  # saves into. Change RESULT where it is configured to evaluate another
  # run.
  local avg_ckpt="${RESULT}/checkpoint_last5.pt"
  local gen_out="${RESULT}/checkpoint_last5.gen"

  # average last 5 checkpoints
  python scripts/average_checkpoints.py \
    --inputs "${RESULT}" \
    --num-epoch-checkpoints 5 \
    --output "${avg_ckpt}" || return 1

  # generate results & quick evaluate
  # Truncate (>) rather than append (>>) so that re-running the stage does
  # not pile up hypotheses from earlier runs in the same file.
  LC_ALL=C.UTF-8 CUDA_VISIBLE_DEVICES=0 fairseq-generate "${BIN}" \
    --path "${avg_ckpt}" \
    --beam 5 --remove-bpe --lenpen 0.5 > "${gen_out}" || return 1

  # compound split & re-run evaluate
  # NOTE(review): the .sys/.ref files scored below are presumably written
  # by compound_split_bleu.sh - confirm against that script.
  bash scripts/compound_split_bleu.sh "${gen_out}"
  LC_ALL=C.UTF-8 python fairseq_cli/score.py \
    --sys "${gen_out}.sys" \
    --ref "${gen_out}.ref"
}

if [ "${stage}" -le 2 ]; then
  echo "Stage 2: Decoding/Generation"
  decode_and_score || exit 1
fi