#!/usr/bin/env bash
# config for 4 V100 32G
#
# Sets up the paths and dataset names used by the training command below.
# path_to_data and path_to_ckpt must be provided: either fill them in here
# or export them in the environment before running this script.

# -e: stop at the first failing command; -u: using an unset variable is an error.
set -eu

# Restrict the job to GPUs 0-3.
export CUDA_VISIBLE_DEVICES=0,1,2,3

src_lang=en
tgt_lang=de

# Take the values from the environment when present; otherwise they stay empty
# and the checks below abort the script.
path_to_data=${path_to_data:-}
path_to_ckpt=${path_to_ckpt:-}

# Fail fast: with empty values the derived paths would start at the filesystem
# root (e.g. /checkpoints/...), and mkdir would try to create them there.
: "${path_to_data:?path_to_data is empty; set it to the data root directory}"
: "${path_to_ckpt:?path_to_ckpt is empty; set it to the checkpoint root directory}"

# Language-pair specific data directory passed to train.py.
DATA_DIR="${path_to_data}/${src_lang}-${tgt_lang}"

# Checkpoint file handed to the --load-pretrained-* options of train.py.
pretrain_nmt="${path_to_ckpt}/checkpoints/spm_${src_lang}_spm_${tgt_lang}_wmt_only_nmt_enc6/avg_last_10_update_checkpoint.pt"

adapter=e2e

# Output directory for this run (checkpoints and dcm.log).
SAVE_DIR="${path_to_ckpt}/checkpoints/dcm_zs_${src_lang}_asr_for_${tgt_lang}_${adapter}_shrink_ot10"

mkdir -p -- "${SAVE_DIR}"

# Comma-separated subset names passed to --train-subset / --valid-subset.
train_data=train_asr,train-clean-100,train-clean-360,train-other-500
dev_data=dev_asr

# Launch training. Both stdout and stderr are written to ${SAVE_DIR}/dcm.log.
# All expansions are quoted so paths containing spaces or glob characters are
# passed to train.py as single arguments.
# NOTE(review): --noise-token was previously written as '"'"'▁NOISE'"'"', which
# in a standalone script expands to the token wrapped in literal quote
# characters ("'▁NOISE'"). It is now passed as the bare token ▁NOISE; confirm
# this matches the entry in the dictionary used by the model.
python train.py "${DATA_DIR}" \
    --save-dir "${SAVE_DIR}" \
    --config-yaml config_text_zs.yaml \
    --train-subset "${train_data}" --valid-subset "${dev_data}" \
    --num-workers 8 \
    --task dcm \
    --arch s2t_dcm --share-decoder-input-output-embed \
    --user-dir examples/dcm \
    --max-epoch 60 --update-mix-data \
    --optimizer adam --lr-scheduler inverse_sqrt \
    --lr 0.002 --update-freq 2 --clip-norm 10.0 \
    --criterion guided_label_smoothed_cross_entropy_with_ctc --zero-shot \
    --ctc-weight 1.0 --zero-infinity \
    --label-smoothing 0.1 --max-tokens 40000 --max-tokens-text 10000 \
    --max-positions-text 400 --seed 2 \
    --encoder-layers 12 --text-encoder-layers 6 --decoder-layers 6 \
    --dropout 0.1 --warmup-updates 10000 \
    --text-input-cost-ratio 0.5 --enc-grad-mult 1.0 \
    --langpairs "${src_lang}-${tgt_lang}" --noise-token '▁NOISE' \
    --mask-text-ratio 0.0 --max-tokens-valid 5000 --ddp-backend no_c10d \
    --log-interval 100 --data-buffer-size 50 \
    --best-checkpoint-metric wer \
    --save-interval-updates 1000 --keep-interval-updates 10 --keep-last-epochs 10 \
    --load-pretrained-text-encoder-from "${pretrain_nmt}" \
    --load-pretrained-decoder-from "${pretrain_nmt}" \
    --encoder-freeze-module text_encoder,embed_adapter \
    --decoder-freeze-module decoder \
    --shrink-ctc --adapter "${adapter}" --ot-weight 10. > "${SAVE_DIR}/dcm.log" 2>&1

