#!/usr/bin/env bash
# Launch single-node bilingual NMT training (IWSLT14, en -> de) via torchrun.
#
# Usage: <script> GPU_IDS WORLD_SIZE MODEL
#   GPU_IDS     value exported as CUDA_VISIBLE_DEVICES, e.g. "0" or "0,1"
#   WORLD_SIZE  number of worker processes (torchrun --nproc_per_node);
#               also forwarded to nmt_run.py as --world_size
#   MODEL       forwarded to nmt_run.py as --model
#               (options noted by the author: 'bt', 'tm', 'tm_big', 'rnnnmt')
set -euo pipefail

usage() {
  printf 'Usage: %s GPU_IDS WORLD_SIZE MODEL\n' "${0##*/}" >&2
  exit 2
}

die() {
  printf 'error: %s\n' "$*" >&2
  exit 2
}

# Validate positional arguments up front so a missing/typoed argument fails
# here instead of producing empty flags deep inside torchrun / nmt_run.py.
(( $# >= 3 )) || usage
[[ -n "$1" ]] || die "GPU_IDS must not be empty (it would hide all GPUs)"
[[ "$2" =~ ^[1-9][0-9]*$ ]] || die "WORLD_SIZE must be a positive integer, got '$2'"
[[ -n "$3" ]] || die "MODEL must not be empty"

GPU_IDS=$1
WORLD_SIZE=$2
MODEL=$3

TASK='iwslt14'
MODE='train'
PORT=24388

SRC_LANG='en'
TRG_LANG='de'
AHEAD=1000 # previously 18000
SEED=0

# With a joined dictionary, source and target share one vocabulary file;
# otherwise both dictionary arguments are passed as empty strings.
JOINED_DICT=1
if (( JOINED_DICT == 1 )); then
  DICT1='vocab.10000.joined.pkl'
  DICT2=$DICT1
else
  DICT1=''
  DICT2=''
fi
SAVE_DIR='./results/'
DATA_DIR='./data/'

TRAIN_FILE1="bpe.train.${SRC_LANG}"
TRAIN_FILE2="bpe.train.${TRG_LANG}"
VALID_FILE1="bpe.valid.${SRC_LANG}"
VALID_FILE2="bpe.valid.${TRG_LANG}"
TEST_BATCH_SIZE=100
TOKEN_SIZE=3600      # for the token-based dataloader
TEST_TOKEN_SIZE=3000 # previously tried: 8000, 9000
SORTING=1
MAX_LENGTH=150
TEST_MAX_LENGTH=150

#LIMIT='500k'
OPTIMIZER='radam'
LR=0.0005
OPT_SCHEDULED=0
DATA_TYPE='fp32'
OPT_START=4000 # warmup steps (4000 is default)
LABEL_SMOOTHING=0.1
UPDATE_STEP=1
GRAD_CLIP=1.0
PATIENCE=50     # previously 30
N_CHECKPOINT=8  # 8 -> (8+2) checkpoint ensemble

DIM_MODEL=512
DIM_WEMB=$DIM_MODEL # should be the same as DIM_MODEL
DROPOUT_P=0.3
DIM_FF=1024
N_LAYERS=6
TM_N_HEAD=4
TM_RESNORM_TYPE='norm_res' # 'norm_res' : our default, 'res_norm' : (A. Vaswani 2017) default

# NegAtt settings
N_HEAD_NEG=1
NEGATT_MODE='separam' # const, separam
POS_LAMBDA=1.0
NEG_LAMBDA=1.0
NEGATT_APPLY='full'
NEG_KEY=0

# Build the nmt_run.py argument list as an array so every value stays a single,
# correctly quoted word (empty dictionary names included).
nmt_args=(
  --translation_task="$TASK" --mode="$MODE" --world_size="$WORLD_SIZE" --port="$PORT"
  --src_lang="$SRC_LANG" --trg_lang="$TRG_LANG" --ahead="$AHEAD"
  --dataset_seed="$SEED" --src_dict="$DICT1" --trg_dict="$DICT2" --save_dir="$SAVE_DIR"
  --data_dir="$DATA_DIR" --train_src_file="$TRAIN_FILE1" --train_trg_file="$TRAIN_FILE2"
  --valid_src_file="$VALID_FILE1" --valid_trg_file="$VALID_FILE2"
  --test_batch_size="$TEST_BATCH_SIZE" --token_size="$TOKEN_SIZE"
  --test_token_size="$TEST_TOKEN_SIZE" --sorting="$SORTING"
  --max_length="$MAX_LENGTH" --test_max_length="$TEST_MAX_LENGTH"
  --lr="$LR" --opt_scheduled="$OPT_SCHEDULED"
  --optimizer="$OPTIMIZER" --opt_start="$OPT_START" --data_type="$DATA_TYPE"
  --label_smoothing="$LABEL_SMOOTHING" --update_step="$UPDATE_STEP" --grad_clip="$GRAD_CLIP"
  --patience="$PATIENCE" --n_checkpoint="$N_CHECKPOINT"
  --model="$MODEL" --dim_model="$DIM_MODEL" --dim_wemb="$DIM_WEMB" --dropout_p="$DROPOUT_P"
  --tm_dim_ff="$DIM_FF" --tm_n_layers="$N_LAYERS" --tm_n_head="$TM_N_HEAD"
  --tm_resnorm_type="$TM_RESNORM_TYPE" --pos_lambda="$POS_LAMBDA"
  --n_head_neg="$N_HEAD_NEG" --negatt_mode="$NEGATT_MODE" --neg_lambda="$NEG_LAMBDA"
  --negatt_apply="$NEGATT_APPLY" --neg_key="$NEG_KEY"
  --print_every=25 --valid_start=250 --valid_every=250
)

echo "bi training"

# localhost:0 lets the c10d rendezvous backend pick a free port on this node.
CUDA_VISIBLE_DEVICES="$GPU_IDS" torchrun \
  --rdzv_backend=c10d --rdzv_endpoint=localhost:0 --nnodes=1 \
  --nproc_per_node="$WORLD_SIZE" \
  nmt_run.py "${nmt_args[@]}"
