# import editdistance
from jiwer.process import process_characters, process_words
from tqdm import tqdm


def compute_metrics(
    pred_tokens_all,
    pred_words_all,
    target_tokens_all,
    target_words_all,
    skip_duplicate_targets=False,
):
    """Compute corpus-level CER and WER from predictions and references.

    Edit operations (substitutions, deletions, insertions, hits) are summed
    over all utterances using jiwer's alignment, and each rate is computed
    once at the end from the totals, rather than by averaging per-utterance
    rates.

    Args:
        pred_tokens_all: Per-utterance sequences of predicted tokens.
        pred_words_all: Per-utterance sequences of predicted words.
        target_tokens_all: Per-utterance sequences of reference tokens.
        target_words_all: Per-utterance sequences of reference words.
        skip_duplicate_targets: If True, only the first utterance with a given
            reference word sequence is scored and later repeats are skipped.
            Defaults to False, which scores every utterance.

    Returns:
        A ``(cer, wer)`` tuple, both as percentages. Note that CER comes first.

    Raises:
        AssertionError: If the input sequences differ in length. These checks
            use ``assert``, so they are skipped under ``python -O``.
        ZeroDivisionError: If the totals leave nothing to score, e.g. when the
            inputs are empty.
    """
    assert len(pred_tokens_all) == len(target_tokens_all), (
        f"Expected to gather the same amount of tokens, "
        f"got: {len(pred_tokens_all)} for predictions and "
        f"{len(target_tokens_all)} for target"
    )
    assert len(pred_words_all) == len(target_words_all), (
        f"Expected to gather the same amount of words, "
        f"got: {len(pred_words_all)} for predictions and "
        f"{len(target_words_all)} for target"
    )
    # zip() below would silently truncate to the shortest input, so the token-
    # and word-level inputs must also describe the same number of utterances.
    assert len(pred_tokens_all) == len(pred_words_all), (
        f"Expected the same number of utterances at token and word level, "
        f"got: {len(pred_tokens_all)} token sequences and "
        f"{len(pred_words_all)} word sequences"
    )

    subs_word, dels_word, inserts_word, hits_word = 0, 0, 0, 0
    subs_char, dels_char, inserts_char, hits_char = 0, 0, 0, 0
    # Reference word sequences already scored; used only when deduplicating.
    seen_targets = set()

    for pred_tokens, pred_words, target_tokens, target_words in tqdm(
        zip(pred_tokens_all, pred_words_all, target_tokens_all, target_words_all),
        total=len(pred_tokens_all),
        desc="Computing metrics...",
        leave=True,
    ):
        if skip_duplicate_targets:
            # Tuples are hashable, so membership checks are O(1) on average.
            target_key = tuple(target_words)
            if target_key in seen_targets:
                print(f"Sentence {target_words} already processed, skip.")
                continue
            seen_targets.add(target_key)

        word_process_out = process_words(" ".join(target_words), " ".join(pred_words))
        subs_word += word_process_out.substitutions
        dels_word += word_process_out.deletions
        inserts_word += word_process_out.insertions
        hits_word += word_process_out.hits

        # Collapse a doubled leading "|" token in the prediction to a single
        # one. The length guard prevents an IndexError on predictions with
        # fewer than two tokens.
        if len(pred_tokens) >= 2 and pred_tokens[0] == pred_tokens[1] == "|":
            pred_tokens = pred_tokens[1:]
        # NOTE(review): tokens are joined with spaces before character-level
        # alignment, so the separator spaces are presumably scored as
        # characters too — confirm this is the intended CER definition.
        utterance_process_out = process_characters(" ".join(target_tokens), " ".join(pred_tokens))
        subs_char += utterance_process_out.substitutions
        dels_char += utterance_process_out.deletions
        inserts_char += utterance_process_out.insertions
        hits_char += utterance_process_out.hits

    # Error rate = (S + D + I) / N, where N = H + S + D is the reference length.
    wer = (subs_word + dels_word + inserts_word) / (hits_word + subs_word + dels_word)
    cer = (subs_char + dels_char + inserts_char) / (hits_char + subs_char + dels_char)

    return cer * 100, wer * 100
