#!/usr/bin/env bash
#
# Launch TNF masked-LM pre-training: install the TNF checkout, then run
# train.py with the hyperparameters defined further down.
# Replace the /path/to/... placeholders before running.

# Abort on the first failing command, on use of an unset variable, and on a
# failure anywhere in a pipeline, so that e.g. a failed `cd` or package install
# cannot silently lead to a training run against the wrong code or data.
set -euo pipefail

echo 'Prepare Data'
PREFIX=/path/to/data_dir
DATA_DIR=$PREFIX/data/                          # positional data argument of train.py
TNF_DATA_DIR=$PREFIX/tnf_data/                  # passed as --tnf-data
TNF_SUBWORD_DATA_DIR=$PREFIX/tnf_subword_data/  # passed as --tnf-subword-data
# NOTE(review): this nests a second "path/to" under PREFIX, which looks like a
# leftover placeholder -- confirm the intended checkpoint directory.
SAVE_DIR=$PREFIX/path/to/save_dir/              # checkpoint directory (--save-dir)

echo 'Prepare training'
# Fail loudly if the checkout is missing. Otherwise the editable install below
# would run against whatever directory the script was started from.
cd /path/to/TNF || { echo 'cannot cd to /path/to/TNF' >&2; exit 1; }
# Install the checkout in editable mode. `python -m pip` ties the install to the
# same `python` interpreter that later runs train.py, rather than to whichever
# `pip` executable happens to come first on PATH.
python -m pip install --editable . || { echo 'pip install failed' >&2; exit 1; }

echo 'Start training'
TNF_LAMBDA=0.5          # TNF Lambda
TNF_GAMMA=0.1           # TNF Gamma
TOTAL_UPDATES=1000000   # Total number of training steps
WARMUP_UPDATES=10000    # Warmup the learning rate over this many updates
PEAK_LR=0.0001          # Peak learning rate, adjust as needed
TOKENS_PER_SAMPLE=512   # Max sequence length
# NOTE(review): MAX_POSITIONS is not passed to train.py below, so it is
# currently unused -- confirm whether --max-positions should be set.
MAX_POSITIONS=512       # Num. positional embeddings (usually same as above)
MAX_SENTENCES=8         # Number of sequences per batch (batch size)
UPDATE_FREQ=4           # Increase the batch size
SEED=100                # Random seed

# TOTAL_UPDATES feeds both the LR schedule (--total-num-update) and the stopping
# criterion (--max-update), so the two values cannot drift apart.
# --tensorboard-logdir . is relative to the working directory set by the
# preceding `cd`.
# --restore-file points at the last checkpoint in SAVE_DIR; ${SAVE_DIR%/} drops
# the trailing slash to avoid a doubled '/' in the path.
python train.py "$DATA_DIR" --num-workers 8 --ddp-backend=c10d \
       --tnf-data "$TNF_DATA_DIR" --tnf-subword-data "$TNF_SUBWORD_DATA_DIR" \
       --task tnf_masked_lm --criterion tnf_masked_lm \
       --arch tnf_base --sample-break-mode complete --tokens-per-sample "$TOKENS_PER_SAMPLE" \
       --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 0.0 \
       --lr-scheduler polynomial_decay --lr "$PEAK_LR" \
       --warmup-updates "$WARMUP_UPDATES" --total-num-update "$TOTAL_UPDATES" \
       --dropout 0.1 --attention-dropout 0.1 --weight-decay 0.01 \
       --max-sentences "$MAX_SENTENCES" --update-freq "$UPDATE_FREQ" --seed "$SEED" \
       --mask-prob 0.15 --embedding-normalize \
       --max-update "$TOTAL_UPDATES" --log-format simple --log-interval 10 --tensorboard-logdir . \
       --keep-updates-list 20000 50000 100000 200000 400000 600000 800000 1000000 \
       --save-interval-updates 10000 --keep-interval-updates 5 \
       --no-epoch-checkpoints --skip-invalid-size-inputs-valid-test \
       --save-dir "$SAVE_DIR" \
       --tnf-lambda "$TNF_LAMBDA" --tnf-gamma "$TNF_GAMMA" \
       --update-tnf-lambda 0 --tnf-emb-zero-init 1 \
       --update-tnf-emb mask --ctx windowavg --ctx-window-size 32 \
       --restore-file "${SAVE_DIR%/}/checkpoint_last.pt"
