#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Implement unsupervised metric for decoding hyperparameter selection:
    $$ alpha * LM_PPL + ViterbitUER(%) * 100 $$
"""
import argparse
import logging
import sys

import editdistance

# Script-level logging setup: records at INFO and above go to stdout (main()
# reports the final UER through logger.info). The logger.debug calls in this
# file are therefore suppressed unless the level is lowered.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (e.g. when an importing module configured logging first), which is
# presumably why the root level is also set explicitly on the line below.
logging.root.setLevel(logging.INFO)
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger(__name__)


def get_parser():
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with two required options:
        ``-s/--hypo`` (path to the hypothesis transcription) and
        ``-r/--reference`` (path to the reference transcription).
    """
    cli = argparse.ArgumentParser()
    # (flags, help text); both options are mandatory.
    option_specs = (
        (("-s", "--hypo"), "hypo transcription"),
        (("-r", "--reference"), "reference transcription"),
    )
    for flags, help_text in option_specs:
        cli.add_argument(*flags, help=help_text, required=True)
    return cli


def compute_wer(ref_uid_to_tra, hyp_uid_to_tra, g2p):
    """Compute the corpus-level error rate of hypotheses against references.

    Each utterance id in ``hyp_uid_to_tra`` is scored against the entry with
    the same id in ``ref_uid_to_tra``. The summed edit distance is divided by
    the summed number of reference tokens.

    Args:
        ref_uid_to_tra: mapping from utterance id to reference transcription;
            references are tokenized on whitespace.
        hyp_uid_to_tra: mapping from utterance id to hypothesis transcription.
        g2p: optional callable that turns a hypothesis string into a token
            sequence (presumably a grapheme-to-phoneme converter). When
            ``None``, hypotheses are tokenized on whitespace instead.

    Returns:
        The error rate as a fraction (not a percentage). If the references
        contain no tokens at all, this returns ``0.0`` when there are also no
        edits and ``float("inf")`` otherwise, instead of dividing by zero.

    Raises:
        KeyError: if a hypothesis id is missing from ``ref_uid_to_tra``.
    """
    d_cnt = 0  # total edit distance
    w_cnt = 0  # total reference tokens
    w_cnt_h = 0  # total hypothesis tokens
    for uid, hyp_tra in hyp_uid_to_tra.items():
        ref = ref_uid_to_tra[uid].split()
        if g2p is not None:
            # Drop empty, apostrophe and space tokens from the g2p output, then
            # strip one trailing numeric character from each token (presumably
            # a stress marker, e.g. "AH0" -> "AH"). Skipping empty tokens
            # keeps p[-1] from raising IndexError.
            hyp = [
                p[:-1] if p[-1].isnumeric() else p
                for p in g2p(hyp_tra)
                if p and p not in ("'", " ")
            ]
        else:
            hyp = hyp_tra.split()
        d_cnt += editdistance.eval(ref, hyp)
        w_cnt += len(ref)
        w_cnt_h += len(hyp)

    if w_cnt > 0:
        wer = float(d_cnt) / w_cnt
    else:
        # No reference tokens: the rate is undefined. Report a perfect score
        # only when nothing had to be edited.
        wer = 0.0 if d_cnt == 0 else float("inf")

    logger.debug(
        f"wer = {wer * 100:.2f}%; num. of ref words = {w_cnt}; "
        f"num. of hyp words = {w_cnt_h}; num. of sentences = {len(hyp_uid_to_tra)}"
    )
    return wer


def main():
    """Compute and log the UER of a hypothesis file against a reference file.

    Line i of ``--hypo`` is compared with line i of ``--reference``. Both lines
    are split on whitespace. The summed edit distance is divided by the total
    number of reference tokens and logged as a percentage.

    If the files have different numbers of lines, only the common prefix is
    scored and a warning is logged. If the reference contains no tokens, an
    error is logged and the process exits with status 1.
    """
    import itertools

    args = get_parser().parse_args()

    errs = 0
    count = 0
    with open(args.hypo, "r") as hf, open(args.reference, "r") as rf:
        # Plain zip() would silently stop at the shorter file; zip_longest
        # lets us detect the mismatch and report it.
        for n_scored, (h, r) in enumerate(itertools.zip_longest(hf, rf)):
            if h is None or r is None:
                logger.warning(
                    f"{args.hypo} and {args.reference} have different numbers "
                    f"of lines; scoring only the first {n_scored} line pairs"
                )
                break
            h = h.rstrip().split()
            r = r.rstrip().split()
            errs += editdistance.eval(r, h)
            count += len(r)

    if count == 0:
        # Avoid ZeroDivisionError: the UER is undefined without reference tokens.
        logger.error(f"no reference tokens found in {args.reference}; UER is undefined")
        sys.exit(1)

    logger.info(f"UER: {errs / count * 100:.2f}%")


if __name__ == "__main__":
    main()


def load_tra(tra_path):
    with open(tra_path, "r") as f:
        uid_to_tra = {}
        for line in f:
            uid, tra = line.split(None, 1)
            uid_to_tra[uid] = tra
    logger.debug(f"loaded {len(uid_to_tra)} utterances from {tra_path}")
    return uid_to_tra
