#!/bin/bash
# Translation direction; passed to fairseq-train as --source-lang/--target-lang.
SRC_LANG='de'
TGT_LANG='en'

# Dataset directory (first positional argument of the fairseq commands below)
# and the directory that receives the training checkpoints (--save-dir).
data_dir='data-bin/wmt14_en_de'
checkpoint_dir='./ckpt/bi_hmm_wmt14_de_en'

# Make GPU indices 0-7 visible to every process started from this script.
CUDA_VISIBLE_DEVICES='0,1,2,3,4,5,6,7'
export CUDA_VISIBLE_DEVICES

# Train the model with fairseq-train.
#
# The arguments live in an array instead of one backslash-continued command:
#   * every expansion is quoted, so a path containing whitespace stays one word;
#   * there is no trailing line continuation that could swallow the next line;
#   * the option groups can be commented in place.
# Argument order is the same as before.
train_args=(
    # Dataset directory, plugin directory and translation direction.
    "${data_dir}"
    --user-dir fs_plugins --source-lang "${SRC_LANG}" --target-lang "${TGT_LANG}"

    # Task.
    --task translation_lev_modified --noise full_mask

    # Model architecture and decoding defaults stored with the model.
    --arch bi_dag_nat
    --decoder-learned-pos --encoder-learned-pos
    --share-all-embeddings --activation-fn gelu
    --apply-bert-init
    --links-feature feature:position --decode-strategy viterbi --normalize-length
    # NOTE(review): 1200 = 150 * 8.0; presumably the target limit is meant to
    # track source limit times upsample scale - confirm before changing one.
    --max-source-positions 150 --max-target-positions 1200 --src-upsample-scale 8.0

    # Training criterion.
    --criterion bi_dag_nat_loss --train-mode sum --no-force-emit
    --length-loss-factor 0 --max-transition-length -1
    --glat-p 0.5:0.1@200k --glance-strategy number-random

    # Optimizer, regularization and learning-rate schedule.
    --optimizer adam --adam-betas '(0.9,0.999)' --fp16
    --label-smoothing 0.0 --weight-decay 0.01 --dropout 0.1
    --lr-scheduler inverse_sqrt --warmup-updates 10000
    --clip-norm 0.1 --lr 0.0005 --warmup-init-lr '1e-07' --stop-min-lr '1e-09'
    --ddp-backend c10d

    # Batching, distribution and run length. The world size of 8 matches the
    # eight device indices listed in CUDA_VISIBLE_DEVICES.
    --max-tokens 4096 --update-freq 2 --distributed-world-size 8 --grouped-shuffling
    --max-update 300000 --max-tokens-valid 4096
    --save-interval 1 --save-interval-updates 10000
    --seed 0

    # Validation with BLEU; the best checkpoints are ranked by that metric.
    --valid-subset valid
    --validate-interval 1 --validate-interval-updates 10000
    --eval-bleu --eval-bleu-args '{"iter_decode_max_iter": 0, "iter_decode_with_beam": 1}'
    --eval-bleu-detok moses --eval-bleu-remove-bpe --eval-bleu-print-samples
    --fixed-validation-seed 7
    --best-checkpoint-metric bleu --maximize-best-checkpoint-metric

    # Checkpoint retention, output directory and logging.
    --keep-best-checkpoints 5 --keep-last-epochs 10 --save-dir "${checkpoint_dir}"
    --log-format 'simple' --log-interval 100

    # Plugin-specific switches.
    --torch-dag-best-alignment
    --torch-dag-loss --find-unused-parameters
)
fairseq-train "${train_args[@]}"


# Average checkpoints after training is done.
# The averaging script reads from the checkpoint directory and is asked for the
# 5 best checkpoints according to the "bleu" metric (--max-metric is passed, so
# presumably higher is better); the result is written to average.pt inside the
# same directory.
average_checkpoint_path="${checkpoint_dir}/average.pt"

# Quoted expansions keep each path a single argument, so an empty or
# space-containing directory cannot shift --inputs/--output onto other tokens.
average_args=(
    --inputs "${checkpoint_dir}"
    --max-metric --best-checkpoints-metric bleu --num-best-checkpoints-metric 5
    --output "${average_checkpoint_path}"
)
python3 ./fs_plugins/scripts/average_checkpoints.py "${average_args[@]}"


# Decode the test subset from the averaged checkpoint with the "bi_greedy"
# decode strategy, selected through --model-overrides.
echo "greedy-----------------------------------------------------------------------------"
gen_args=(
    "${data_dir}"
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1
    # Reuse the language pair defined at the top of the script rather than a
    # second hard-coded copy that could drift from the training direction.
    -s "${SRC_LANG}" -t "${TGT_LANG}"
    --remove-bpe --max-tokens 4096 --seed 0
    # Single-quoted JSON needs no backslash escapes.
    --model-overrides '{"decode_strategy":"bi_greedy"}'
    --path "${average_checkpoint_path}" --quiet
)
fairseq-generate "${gen_args[@]}"

# Decode the test subset from the averaged checkpoint with the "bi_lookahead"
# decode strategy and decode_beta set to 1 through --model-overrides.
echo "lookahead--------------------------------------------------------------------------"
gen_args=(
    "${data_dir}"
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1
    # Reuse the language pair defined at the top of the script rather than a
    # second hard-coded copy that could drift from the training direction.
    -s "${SRC_LANG}" -t "${TGT_LANG}"
    --remove-bpe --max-tokens 4096 --seed 0
    # Single-quoted JSON needs no backslash escapes.
    --model-overrides '{"decode_strategy":"bi_lookahead","decode_beta":1}'
    --path "${average_checkpoint_path}" --quiet
)
fairseq-generate "${gen_args[@]}"

# Decode the test subset from the averaged checkpoint with the "viterbi"
# decode strategy and no further overrides.
echo "viterbi----------------------------------------------------------------------------"
gen_args=(
    "${data_dir}"
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1
    # Reuse the language pair defined at the top of the script rather than a
    # second hard-coded copy that could drift from the training direction.
    -s "${SRC_LANG}" -t "${TGT_LANG}"
    --remove-bpe --max-tokens 4096 --seed 0
    # Single-quoted JSON needs no backslash escapes.
    --model-overrides '{"decode_strategy":"viterbi"}'
    --path "${average_checkpoint_path}" --quiet
)
fairseq-generate "${gen_args[@]}"

# Decode the test subset from the averaged checkpoint with the "viterbi"
# decode strategy and normalize_length switched on through --model-overrides.
echo "viterbi norm length----------------------------------------------------------------"
gen_args=(
    "${data_dir}"
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1
    # Reuse the language pair defined at the top of the script rather than a
    # second hard-coded copy that could drift from the training direction.
    -s "${SRC_LANG}" -t "${TGT_LANG}"
    --remove-bpe --max-tokens 4096 --seed 0
    # Single-quoted JSON needs no backslash escapes.
    --model-overrides '{"decode_strategy":"viterbi","normalize_length":1}'
    --path "${average_checkpoint_path}" --quiet
)
fairseq-generate "${gen_args[@]}"

# Decode the test subset from the averaged checkpoint with the "viterbi"
# decode strategy, with both normalize_length and viterbi_wo_emit switched on
# through --model-overrides.
echo "viterbi wo emit, norm length-------------------------------------------------------"
gen_args=(
    "${data_dir}"
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1
    # Reuse the language pair defined at the top of the script rather than a
    # second hard-coded copy that could drift from the training direction.
    -s "${SRC_LANG}" -t "${TGT_LANG}"
    --remove-bpe --max-tokens 4096 --seed 0
    # Single-quoted JSON needs no backslash escapes.
    --model-overrides '{"decode_strategy":"viterbi","normalize_length":1,"viterbi_wo_emit":1}'
    --path "${average_checkpoint_path}" --quiet
)
fairseq-generate "${gen_args[@]}"

# Decode the test subset from the averaged checkpoint with the "viterbi"
# decode strategy and viterbi_wo_emit switched on through --model-overrides.
echo "viterbi wo emit--------------------------------------------------------------------"
gen_args=(
    "${data_dir}"
    --gen-subset test --user-dir fs_plugins --task translation_lev_modified
    --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1
    # Reuse the language pair defined at the top of the script rather than a
    # second hard-coded copy that could drift from the training direction.
    -s "${SRC_LANG}" -t "${TGT_LANG}"
    --remove-bpe --max-tokens 4096 --seed 0
    # Single-quoted JSON needs no backslash escapes.
    --model-overrides '{"decode_strategy":"viterbi","viterbi_wo_emit":1}'
    --path "${average_checkpoint_path}" --quiet
)
fairseq-generate "${gen_args[@]}"

#######################################
# Sweep viterbi_penalty from 0 to 1.0 in steps of 0.1 on the valid and test
# subsets, decoding with the "viterbi" strategy and normalize_length enabled.
# A banner naming the subset and the penalty is printed before each run.
# Globals:   data_dir, SRC_LANG, TGT_LANG, average_checkpoint_path (read)
# Arguments: none
# Outputs:   one banner line per run on stdout, plus fairseq-generate's output
#######################################
viterbi_penalty_sweep() {
    local subset penalty overrides
    local -a gen_args

    for subset in valid test; do
        # Force the C locale so seq prints "0.1" and never "0,1": the value
        # is pasted into JSON below. The substitution is unquoted on purpose,
        # word splitting yields one penalty per iteration.
        for penalty in $(LC_ALL=C seq 0 0.1 1.0); do
            # The banner previously expanded an undefined variable and always
            # printed an empty subset name; it now names the loop's subset.
            echo "On ${subset}, viterbi penalty ${penalty}--------------------------------------------------------------"

            printf -v overrides \
                '{"decode_strategy":"viterbi","viterbi_penalty":%s,"normalize_length":1}' \
                "${penalty}"

            gen_args=(
                "${data_dir}"
                --gen-subset "${subset}" --user-dir fs_plugins --task translation_lev_modified
                --iter-decode-max-iter 0 --iter-decode-eos-penalty 0 --beam 1
                # Same language pair as the variables defined at the top of
                # the script, instead of a hard-coded copy.
                -s "${SRC_LANG}" -t "${TGT_LANG}"
                --remove-bpe --max-tokens 4096 --seed 0
                --model-overrides "${overrides}"
                --path "${average_checkpoint_path}" --quiet
            )
            fairseq-generate "${gen_args[@]}"
        done
    done
}

viterbi_penalty_sweep

