#!/usr/bin/env bash
# Evaluate phase 2: run validate.py on the ASR test sets for a trained DCM
# checkpoint, then score the outputs with wer.py.
#
# Usage: bash <this script> <best|epoch|update>
#   best           evaluate checkpoint_best.pt from SAVE_DIR
#   epoch|update   evaluate the average of the last 10 epoch/update
#                  checkpoints (created on first use)
#
# Set path_to_data / path_to_ckpt below, or export PATH_TO_DATA / PATH_TO_CKPT.
set -euo pipefail

export CUDA_VISIBLE_DEVICES=0

src_lang=en
tgt_lang=de

# Checkpoint selector; see usage above. Abort with a usage line if missing.
ckpt=${1:?"usage: $0 <best|epoch|update>"}

# Fill these in, or provide them through the environment.
path_to_data=${PATH_TO_DATA:-}
path_to_ckpt=${PATH_TO_CKPT:-}
# Fail fast instead of silently evaluating under "/en-de" / "/checkpoints".
: "${path_to_data:?set path_to_data (or export PATH_TO_DATA) to the data root}"
: "${path_to_ckpt:?set path_to_ckpt (or export PATH_TO_CKPT) to the checkpoint root}"

DATA_ROOT=${path_to_data}/${src_lang}-${tgt_lang}

adapter=e2e # or cascade

SAVE_DIR=${path_to_ckpt}/checkpoints/dcm_zs_${src_lang}_asr_for_${tgt_lang}_${adapter}_shrink_ot10

# Resolve the checkpoint file to evaluate:
#   best           -> ${SAVE_DIR}/checkpoint_best.pt (must already exist)
#   anything else  -> ${SAVE_DIR}/avg_last_10_<ckpt>_checkpoint.pt; if it does
#                     not exist yet and ckpt is "epoch" or "update", it is built
#                     by averaging the last 10 checkpoints of that kind.
if [ "${ckpt}" = "best" ]; then
  CHECKPOINT_FILENAME=checkpoint_${ckpt}.pt
else
  CHECKPOINT_FILENAME=avg_last_10_${ckpt}_checkpoint.pt
fi
model=${SAVE_DIR}/${CHECKPOINT_FILENAME}

if [ ! -f "${model}" ]; then
  case "${ckpt}" in
    epoch | update)
      # average_checkpoints.py only understands --num-epoch-checkpoints and
      # --num-update-checkpoints.
      python scripts/average_checkpoints.py \
        --inputs "${SAVE_DIR}" \
        --num-"${ckpt}"-checkpoints 10 \
        --output "${model}"
      ;;
    *)
      # Nothing to average for "best" (or an unknown selector): the file
      # simply has to exist.
      printf 'error: %s not found (only "epoch"/"update" checkpoints can be averaged)\n' \
        "${model}" >&2
      exit 1
      ;;
  esac
fi

data_config=config_text_zs.yaml
task=dcm
infer_results=${SAVE_DIR}/infer_results/${ckpt}
max_tokens=50000

# Splits to evaluate, passed to validate.py as one comma-separated
# --valid-subset value.
testsets=tst-COMMON_asr,tst-HE_asr

# Run validation with the DCM user task. Results go under ${infer_results};
# the scoring loop reads generate-<subset>.txt from there as tab-separated
# (system, reference) columns — presumably written by this run with greedy CTC
# output; confirm against examples/dcm.
# --model-overrides uses Python-literal syntax (note `None`, not JSON `null`).
# NOTE(review): it appears to stop the model from re-loading the pretrained
# text encoder at eval time — confirm.
python fairseq_cli/validate.py "${DATA_ROOT}" \
  --config-yaml "${data_config}" \
  --user-dir examples/dcm \
  --valid-subset "${testsets}" \
  --task "${task}" \
  --path "${model}" \
  --max-tokens "${max_tokens}" \
  --results-path "${infer_results}" \
  --criterion guided_label_smoothed_cross_entropy_with_ctc \
  --ctc-post-process sentencepiece --ctc-greedy-out \
  --model-overrides '{"load_pretrained_text_encoder_from": None}'

# Score every subset with wer.py. Each ${infer_results}/generate-<subset>.txt
# is read as tab-separated lines: column 1 = system output, column 2 = reference.
# Scratch files live in a private temp dir, so concurrent runs cannot clobber
# each other and no unrelated tmp.* files in the working directory get deleted.
tmp_dir=$(mktemp -d)
trap 'rm -rf -- "${tmp_dir}"' EXIT

# testsets is comma-separated and subset names contain no whitespace or glob
# characters, so splitting the unquoted substitution below is intentional.
# shellcheck disable=SC2046
for subset in $(printf '%s\n' "${testsets}" | tr ',' ' '); do
  res_file=${infer_results}/generate-${subset}.txt
  if [ ! -f "${res_file}" ]; then
    printf 'warning: %s not found; skipping %s\n' "${res_file}" "${subset}" >&2
    continue
  fi
  # Label each WER report with the subset it belongs to.
  printf '== %s ==\n' "${subset}"
  cut -f1 "${res_file}" > "${tmp_dir}/sys"
  cut -f2 "${res_file}" > "${tmp_dir}/ref"
  python wer.py -s "${tmp_dir}/sys" -r "${tmp_dir}/ref" \
    --wer-tokenizer 13a --wer-lowercase --wer-remove-punct
done
