import logging
import os
import k2
import kaldialign
import torch
from torch import nn
from typing import Dict, Iterable, List, Optional, TextIO, Tuple, Union
from dataclasses import dataclass, field
from pathlib import Path
from collections import defaultdict

@dataclass
class DecodingResults:
    """Container for batch decoding output.

    It is a dataclass so that it can be built with keyword arguments
    (``DecodingResults(hyps=..., timestamps=..., scores=...)``).
    """

    # timestamps[i][k] contains the frame number on which hyps[i][k]
    # is decoded
    timestamps: List[List[int]]

    # hyps[i] is the recognition results, i.e., word IDs or token IDs
    # for the i-th utterance with fast_beam_search_nbest_LG.
    # The k2 type is quoted so that defining this class does not evaluate
    # `k2` at class-creation time.
    hyps: Union[List[List[int]], "k2.RaggedTensor"]

    # scores[i][k] contains the score of hyps[i][k]. Originally described as
    # a log-prob; the exact scale depends on the decoding method that fills it.
    scores: Optional[List[List[float]]] = None
    
    
# @dataclass
# class Hypothesis:
#     # The predicted tokens so far.
#     # Newly predicted tokens are appended to `ys`.
#     ys: List[int]

#     # The log prob of ys.
#     # It contains only one entry.
#     log_prob: torch.Tensor

#     ac_probs: Optional[List[float]] = None

#     # timestamp[i] is the frame index after subsampling
#     # on which ys[i] is decoded
#     timestamp: List[int] = field(default_factory=list)

#     # the lm score for next token given the current ys
#     lm_score: Optional[torch.Tensor] = None

#     # the RNNLM states (h and c in LSTM)
#     state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None

#     # N-gram LM state
#     state_cost: Optional[NgramLmStateCost] = None

#     # Context graph state
#     context_state: Optional[ContextState] = None

#     num_tailing_blanks: int = 0

#     @property
#     def key(self) -> str:
#         """Return a string representation of self.ys"""
#         return "_".join(map(str, self.ys))


# class HypothesisList(object):
#     def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
#         """
#         Args:
#           data:
#             A dict of Hypotheses. Its key is its `value.key`.
#         """
#         if data is None:
#             self._data = {}
#         else:
#             self._data = data

#     @property
#     def data(self) -> Dict[str, Hypothesis]:
#         return self._data

#     def add(self, hyp: Hypothesis) -> None:
#         """Add a Hypothesis to `self`.

#         If `hyp` already exists in `self`, its probability is updated using
#         `log-sum-exp` with the existing one.

#         Args:
#           hyp:
#             The hypothesis to be added.
#         """
#         key = hyp.key
#         if key in self:
#             old_hyp = self._data[key]  # shallow copy
#             torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
#         else:
#             self._data[key] = hyp

#     def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
#         """Get the most probable hypothesis, i.e., the one with
#         the largest `log_prob`.

#         Args:
#           length_norm:
#             If True, the `log_prob` of a hypothesis is normalized by the
#             number of tokens in it.
#         Returns:
#           Return the hypothesis that has the largest `log_prob`.
#         """
#         if length_norm:
#             return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
#         else:
#             return max(self._data.values(), key=lambda hyp: hyp.log_prob)

#     def remove(self, hyp: Hypothesis) -> None:
#         """Remove a given hypothesis.

#         Caution:
#           `self` is modified **in-place**.

#         Args:
#           hyp:
#             The hypothesis to be removed from `self`.
#             Note: It must be contained in `self`. Otherwise,
#             an exception is raised.
#         """
#         key = hyp.key
#         assert key in self, f"{key} does not exist"
#         del self._data[key]

#     def filter(self, threshold: torch.Tensor) -> "HypothesisList":
#         """Remove all Hypotheses whose log_prob is less than threshold.

#         Caution:
#           `self` is not modified. Instead, a new HypothesisList is returned.

#         Returns:
#           Return a new HypothesisList containing all hypotheses from `self`
#           with `log_prob` being greater than the given `threshold`.
#         """
#         ans = HypothesisList()
#         for _, hyp in self._data.items():
#             if hyp.log_prob > threshold:
#                 ans.add(hyp)  # shallow copy
#         return ans

#     def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
#         """Return the top-k hypothesis.

#         Args:
#           length_norm:
#             If True, the `log_prob` of a hypothesis is normalized by the
#             number of tokens in it.
#         """
#         hyps = list(self._data.items())

#         if length_norm:
#             hyps = sorted(
#                 hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
#             )[:k]
#         else:
#             hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]

#         ans = HypothesisList(dict(hyps))
#         return ans

#     def __contains__(self, key: str):
#         return key in self._data

#     def __iter__(self):
#         return iter(self._data.values())

#     def __len__(self) -> int:
#         return len(self._data)

#     def __str__(self) -> str:
#         s = []
#         for key in self:
#             s.append(key)
#         return ", ".join(s)
    
    
# def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
#     """Return a ragged shape with axes [utt][num_hyps].

#     Args:
#       hyps:
#         len(hyps) == batch_size. It contains the current hypothesis for
#         each utterance in the batch.
#     Returns:
#       Return a ragged shape with 2 axes [utt][num_hyps]. Note that
#       the shape is on CPU.
#     """
#     num_hyps = [len(h) for h in hyps]

#     # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
#     # to get exclusive sum later.
#     num_hyps.insert(0, 0)

#     num_hyps = torch.tensor(num_hyps)
#     row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
#     ans = k2.ragged.create_ragged_shape2(
#         row_splits=row_splits, cached_tot_size=row_splits[-1].item()
#     )
#     return ans


def greedy_search_batch(
    model: nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    blank_penalty: float = 0,
    return_timestamps: bool = False,
) -> Union[List[List[int]], "DecodingResults"]:
    """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.

    Args:
      model:
        The transducer model. It must provide `decoder` (with `blank_id`
        and `context_size`) and `joiner` (with `encoder_proj` and
        `decoder_proj`). An optional `unk_id` attribute is honoured;
        it defaults to `blank_id`.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding.
      blank_penalty:
        Value subtracted from the logit of the blank symbol before taking
        the argmax. 0 disables the penalty.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, return the decoded result, one list of
      token IDs per utterance, in the original batch order.
      Else, return a DecodingResults object containing the decoded result,
      the corresponding timestamps and the per-token scores (the joiner
      logit of each emitted token).
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    # Packing sorts the utterances by length (longest first), so that at
    # frame t only the first batch_sizes[t] utterances are still active.
    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = next(model.parameters()).device

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # Every hypothesis starts with `context_size` placeholder symbols, which
    # are stripped again before returning. All lists below are indexed in
    # the sorted (packed) order.
    hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]

    # timestamp[n][i] is the frame index after subsampling
    # on which hyp[n][i] is decoded
    timestamps = [[] for _ in range(N)]
    # scores[n][i] is the logit on which hyp[n][i] is decoded
    scores = [[] for _ in range(N)]

    decoder_input = torch.tensor(
        hyps,
        device=device,
        dtype=torch.int64,
    )  # (N, context_size)

    decoder_out = model.decoder(decoder_input, need_pad=False)
    decoder_out = model.joiner.decoder_proj(decoder_out)
    # decoder_out: (N, 1, decoder_out_dim)

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
        offset = end

        # Drop the rows of utterances that have already ended.
        decoder_out = decoder_out[:batch_size]

        logits = model.joiner(
            current_encoder_out, decoder_out.unsqueeze(1), project_input=False
        )
        # logits'shape (batch_size, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
        assert logits.ndim == 2, logits.shape

        if blank_penalty != 0:
            # Penalize the blank symbol itself rather than assuming that it
            # lives in column 0.
            logits[:, blank_id] -= blank_penalty

        y = logits.argmax(dim=1).tolist()
        emitted = False
        for i, v in enumerate(y):
            if v not in (blank_id, unk_id):
                hyps[i].append(v)
                timestamps[i].append(t)
                scores[i].append(logits[i, v].item())
                emitted = True
        if emitted:
            # update decoder output
            decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
            decoder_input = torch.tensor(
                decoder_input,
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.decoder(decoder_input, need_pad=False)
            decoder_out = model.joiner.decoder_proj(decoder_out)

    # Strip the initial context and restore the caller's utterance order.
    sorted_ans = [h[context_size:] for h in hyps]
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    ans = [sorted_ans[idx] for idx in unsorted_indices]
    ans_timestamps = [timestamps[idx] for idx in unsorted_indices]
    ans_scores = [scores[idx] for idx in unsorted_indices]

    if not return_timestamps:
        return ans
    else:
        return DecodingResults(
            hyps=ans,
            timestamps=ans_timestamps,
            scores=ans_scores,
        )
        
        
# def modified_beam_search(
#     model: nn.Module,
#     encoder_out: torch.Tensor,
#     encoder_out_lens: torch.Tensor,
#     context_graph: Optional[ContextGraph] = None,
#     beam: int = 4,
#     temperature: float = 1.0,
#     blank_penalty: float = 0.0,
#     return_timestamps: bool = False,
# ) -> Union[List[List[int]], DecodingResults]:
#     """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

#     Args:
#       model:
#         The transducer model.
#       encoder_out:
#         Output from the encoder. Its shape is (N, T, C).
#       encoder_out_lens:
#         A 1-D tensor of shape (N,), containing number of valid frames in
#         encoder_out before padding.
#       beam:
#         Number of active paths during the beam search.
#       temperature:
#         Softmax temperature.
#       return_timestamps:
#         Whether to return timestamps.
#     Returns:
#       If return_timestamps is False, return the decoded result.
#       Else, return a DecodingResults object containing
#       decoded result and corresponding timestamps.
#     """
#     assert encoder_out.ndim == 3, encoder_out.shape
#     assert encoder_out.size(0) >= 1, encoder_out.size(0)

#     packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
#         input=encoder_out,
#         lengths=encoder_out_lens.cpu(),
#         batch_first=True,
#         enforce_sorted=False,
#     )

#     blank_id = model.decoder.blank_id
#     unk_id = getattr(model, "unk_id", blank_id)
#     context_size = model.decoder.context_size
#     device = next(model.parameters()).device

#     batch_size_list = packed_encoder_out.batch_sizes.tolist()
#     N = encoder_out.size(0)
#     assert torch.all(encoder_out_lens > 0), encoder_out_lens
#     assert N == batch_size_list[0], (N, batch_size_list)

#     B = [HypothesisList() for _ in range(N)]
#     for i in range(N):
#         B[i].add(
#             Hypothesis(
#                 ys=[-1] * (context_size - 1) + [blank_id],
#                 log_prob=torch.zeros(1, dtype=torch.float32, device=device),
#                 context_state=None if context_graph is None else context_graph.root,
#                 timestamp=[],
#             )
#         )

#     encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

#     offset = 0
#     finalized_B = []
#     for t, batch_size in enumerate(batch_size_list):
#         start = offset
#         end = offset + batch_size
#         current_encoder_out = encoder_out.data[start:end]
#         current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
#         # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
#         offset = end

#         finalized_B = B[batch_size:] + finalized_B
#         B = B[:batch_size]

#         hyps_shape = get_hyps_shape(B).to(device)

#         A = [list(b) for b in B]

#         B = [HypothesisList() for _ in range(batch_size)]

#         ys_log_probs = torch.cat(
#             [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
#         )  # (num_hyps, 1)

#         decoder_input = torch.tensor(
#             [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
#             device=device,
#             dtype=torch.int64,
#         )  # (num_hyps, context_size)

#         decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
#         decoder_out = model.joiner.decoder_proj(decoder_out)
#         # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

#         # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
#         # as index, so we use `to(torch.int64)` below.
#         current_encoder_out = torch.index_select(
#             current_encoder_out,
#             dim=0,
#             index=hyps_shape.row_ids(1).to(torch.int64),
#         )  # (num_hyps, 1, 1, encoder_out_dim)

#         logits = model.joiner(
#             current_encoder_out,
#             decoder_out,
#             project_input=False,
#         )  # (num_hyps, 1, 1, vocab_size)

#         logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

#         if blank_penalty != 0:
#             logits[:, 0] -= blank_penalty

#         log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)

#         log_probs.add_(ys_log_probs)

#         vocab_size = log_probs.size(-1)

#         log_probs = log_probs.reshape(-1)

#         row_splits = hyps_shape.row_splits(1) * vocab_size
#         log_probs_shape = k2.ragged.create_ragged_shape2(
#             row_splits=row_splits, cached_tot_size=log_probs.numel()
#         )
#         ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)

#         for i in range(batch_size):
#             topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

#             with warnings.catch_warnings():
#                 warnings.simplefilter("ignore")
#                 topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
#                 topk_token_indexes = (topk_indexes % vocab_size).tolist()

#             for k in range(len(topk_hyp_indexes)):
#                 hyp_idx = topk_hyp_indexes[k]
#                 hyp = A[i][hyp_idx]
#                 new_ys = hyp.ys[:]
#                 new_token = topk_token_indexes[k]
#                 new_timestamp = hyp.timestamp[:]
#                 context_score = 0
#                 new_context_state = None if context_graph is None else hyp.context_state
#                 if new_token not in (blank_id, unk_id):
#                     new_ys.append(new_token)
#                     new_timestamp.append(t)
#                     if context_graph is not None:
#                         (
#                             context_score,
#                             new_context_state,
#                         ) = context_graph.forward_one_step(hyp.context_state, new_token)

#                 new_log_prob = topk_log_probs[k] + context_score

#                 new_hyp = Hypothesis(
#                     ys=new_ys,
#                     log_prob=new_log_prob,
#                     timestamp=new_timestamp,
#                     context_state=new_context_state,
#                 )
#                 B[i].add(new_hyp)

#     B = B + finalized_B

#     # finalize context_state, if the matched contexts do not reach final state
#     # we need to add the score on the corresponding backoff arc
#     if context_graph is not None:
#         finalized_B = [HypothesisList() for _ in range(len(B))]
#         for i, hyps in enumerate(B):
#             for hyp in list(hyps):
#                 context_score, new_context_state = context_graph.finalize(
#                     hyp.context_state
#                 )
#                 finalized_B[i].add(
#                     Hypothesis(
#                         ys=hyp.ys,
#                         log_prob=hyp.log_prob + context_score,
#                         timestamp=hyp.timestamp,
#                         context_state=new_context_state,
#                     )
#                 )
#         B = finalized_B

#     best_hyps = [b.get_most_probable(length_norm=True) for b in B]

#     sorted_ans = [h.ys[context_size:] for h in best_hyps]
#     sorted_timestamps = [h.timestamp for h in best_hyps]
#     ans = []
#     ans_timestamps = []
#     unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
#     for i in range(N):
#         ans.append(sorted_ans[unsorted_indices[i]])
#         ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

#     if not return_timestamps:
#         return ans
#     else:
#         return DecodingResults(
#             hyps=ans,
#             timestamps=ans_timestamps,
#         )


def store_transcripts(
    filename: Path, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
) -> None:
    """Write reference/hypothesis pairs to a UTF-8 text file.

    For every entry two lines are written, ``<cut_id>:\\tref=<ref>`` followed
    by ``<cut_id>:\\thyp=<hyp>``.

    Args:
      filename:
        Destination file. It is overwritten if it already exists.
      texts:
        An iterable of (cut_id, ref, hyp) tuples, i.e. the utterance ID,
        the reference transcript and the predicted result.
        For a multi-talker ASR system, ref and hyp may also be lists of
        strings.
      char_level:
        If True, ref and hyp are each concatenated and then split into
        single characters before being written.
    Returns:
      Return None.
    """
    with open(filename, "w", encoding="utf8") as out:
        for cut_id, ref, hyp in texts:
            if char_level:
                ref, hyp = list("".join(ref)), list("".join(hyp))
            out.write(f"{cut_id}:\tref={ref}\n")
            out.write(f"{cut_id}:\thyp={hyp}\n")
        
        
        
def save_results(
    res_dir,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
    suffix=None
):
    """Store transcripts, error statistics and WER summaries for a test set.

    For every decoding setting (a key of `results_dict`) this writes
    ``recogs-<test_set_name>-<key>-<suffix>.txt`` and
    ``errs-<test_set_name>-<key>-<suffix>.txt`` into `res_dir`. It then
    writes a per-test-set WER summary, appends to a cross-test-set summary
    and logs the WER of every setting, best first.

    Args:
      res_dir:
        Directory to write to. It must support the ``/`` operator with a
        string (e.g. a `pathlib.Path`) and is expected to exist already.
      test_set_name:
        Name of the test set; it is used in file names and log messages.
      results_dict:
        Maps a decoding setting to its results. Each result is unpacked
        downstream as a (cut_id, ref, hyp) tuple.
      suffix:
        Appended to every file name. It is formatted as is, so None
        yields a literal ``None``.
    Returns:
      Return None.
    """
    if not results_dict:
        # Previously this case crashed on an unbound loop variable while
        # building the summary file names.
        logging.warning(f"No results to save for {test_set_name}")
        return

    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = res_dir / f"recogs-{test_set_name}-{key}-{suffix}.txt"
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = res_dir / f"errs-{test_set_name}-{key}-{suffix}.txt"
        # Transcripts may contain non-ASCII text, so do not rely on the
        # locale's default encoding (store_transcripts uses utf8 as well).
        with open(errs_filename, "w", encoding="utf8") as f:
            test_set_wers[key] = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=True
            )

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    # Settings sorted by WER, best (lowest) first.
    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])

    # NOTE(review): the two summary file names below embed a single setting
    # name even though the files list all settings. The choice of key (last
    # key of `results_dict` for the first file, setting with the highest WER
    # for the second) reproduces the historical file names; confirm whether
    # the key should be part of these names at all.
    last_key = list(results_dict)[-1]
    worst_key = test_set_wers[-1][0]

    errs_info = res_dir / f"wer-summary-{test_set_name}-{last_key}-{suffix}.txt"
    with open(errs_info, "w", encoding="utf8") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    # This file accumulates rows across calls (one call per test set), so
    # the header is written only when the file is first created.
    wer_info = res_dir / f"wer-summary-all-{worst_key}-{suffix}.txt"
    write_header = not os.path.exists(wer_info)
    with open(wer_info, "a", encoding="utf8") as f:
        if write_header:
            print("dataset\tsettings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}\t{}".format(test_set_name, key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    # Only the first (best) setting is annotated.
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
    
    
    
def write_error_stats(
    f: TextIO,
    test_set_name: str,
    results: List[Tuple[str, List[str], List[str]]],
    enable_log: bool = True,
    compute_CER: bool = False,
    sclite_mode: bool = False,
) -> float:
    """Write statistics based on predicted results and reference transcripts.

    It will write the following to the given file:

        - WER
        - number of insertions, deletions, substitutions, corrects and total
          reference words. For example::

              Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
              reference words (2337 correct)

        - The difference between the reference transcript and predicted result.
          An instance is given below::

            THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES

          The above example shows that the reference word is `EDISON`,
          but it is predicted to `ADDISON` (a substitution error).

          Another example is::

            FOR THE FIRST DAY (SIR->*) I THINK

          The reference word `SIR` is missing in the predicted
          results (a deletion error).

    Args:
      f:
        An open, writable text file to write the statistics to.
      test_set_name:
        Name used to tag the log message.
      results:
        An iterable of tuples. The first element is the cut_id, the second is
        the reference transcript and the third element is the predicted result.
        It is never modified.
      enable_log:
        If True, also print detailed WER to the console.
        Otherwise, it is written only to the given file.
      compute_CER:
        If True, ref and hyp are concatenated and split into characters
        before alignment, so that a character error rate is computed.
      sclite_mode:
        Passed to `kaldialign.align`. The same alignment is used both for
        the error counts and for the per-utterance details.
    Returns:
      Return the total error rate in percent, rounded to two decimals.
      If there are no reference tokens, it is 0.0 when there are no errors
      and ``inf`` otherwise.
    """
    subs: Dict[Tuple[str, str], int] = defaultdict(int)
    ins: Dict[str, int] = defaultdict(int)
    dels: Dict[str, int] = defaultdict(int)

    # `words` stores counts per word, as follows:
    #   corr, ref_sub, hyp_sub, ins, dels
    words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
    num_corr = 0
    ERR = "*"

    # Work on a private list so that the caller's `results` is left untouched
    # (it used to be rewritten in place when compute_CER was set).
    if compute_CER:
        results = [
            (cut_id, list("".join(ref)), list("".join(hyp)))
            for cut_id, ref, hyp in results
        ]
    else:
        results = list(results)

    # Align every utterance once; the alignments are reused below for the
    # per-utterance details so that they always match the counted errors.
    alignments = []
    for cut_id, ref, hyp in results:
        ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
        alignments.append(ali)
        for ref_word, hyp_word in ali:
            if ref_word == ERR:
                ins[hyp_word] += 1
                words[hyp_word][3] += 1
            elif hyp_word == ERR:
                dels[ref_word] += 1
                words[ref_word][4] += 1
            elif hyp_word != ref_word:
                subs[(ref_word, hyp_word)] += 1
                words[ref_word][1] += 1
                words[hyp_word][2] += 1
            else:
                words[ref_word][0] += 1
                num_corr += 1
    ref_len = sum(len(r) for _, r, _ in results)
    sub_errs = sum(subs.values())
    ins_errs = sum(ins.values())
    del_errs = sum(dels.values())
    tot_errs = sub_errs + ins_errs + del_errs

    if ref_len > 0:
        tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
        log_err_rate = f"{tot_errs / ref_len:.2%}"
    else:
        # No reference tokens at all: avoid dividing by zero.
        fallback = 0.0 if tot_errs == 0 else float("inf")
        tot_err_rate = "%.2f" % (100.0 * fallback)
        log_err_rate = f"{fallback:.2%}"

    if enable_log:
        logging.info(
            f"[{test_set_name}] %WER {log_err_rate} "
            f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
            f"{del_errs} del, {sub_errs} sub ]"
        )

    print(f"%WER = {tot_err_rate}", file=f)
    print(
        f"Errors: {ins_errs} insertions, {del_errs} deletions, "
        f"{sub_errs} substitutions, over {ref_len} reference "
        f"words ({num_corr} correct)",
        file=f,
    )
    print(
        "Search below for sections starting with PER-UTT DETAILS:, "
        "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
        file=f,
    )

    print("", file=f)
    print("PER-UTT DETAILS: corr or (ref->hyp)  ", file=f)
    for (cut_id, _, _), ali in zip(results, alignments):
        # Combine successive errors into one "(ref words->hyp words)" group:
        # each erroneous pair is folded into its right neighbour if that
        # neighbour is erroneous too, leaving an empty placeholder behind.
        merged = [[[x], [y]] for x, y in ali]
        for i in range(len(merged) - 1):
            if merged[i][0] != merged[i][1] and merged[i + 1][0] != merged[i + 1][1]:
                merged[i + 1][0] = merged[i][0] + merged[i + 1][0]
                merged[i + 1][1] = merged[i][1] + merged[i + 1][1]
                merged[i] = [[], []]
        # Remove the ERR markers inside each group ...
        merged = [
            [
                [a for a in x if a != ERR],
                [a for a in y if a != ERR],
            ]
            for x, y in merged
        ]
        # ... drop the placeholders, and show an empty side as ERR.
        merged = [pair for pair in merged if pair != [[], []]]
        merged = [
            [
                ERR if x == [] else " ".join(x),
                ERR if y == [] else " ".join(y),
            ]
            for x, y in merged
        ]

        print(
            f"{cut_id}:\t"
            + " ".join(
                (
                    ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
                    for ref_word, hyp_word in merged
                )
            ),
            file=f,
        )

    print("", file=f)
    print("SUBSTITUTIONS: count ref -> hyp", file=f)

    for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
        print(f"{count}   {ref} -> {hyp}", file=f)

    print("", file=f)
    print("DELETIONS: count ref", file=f)
    for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
        print(f"{count}   {ref}", file=f)

    print("", file=f)
    print("INSERTIONS: count hyp", file=f)
    for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
        print(f"{count}   {hyp}", file=f)

    print("", file=f)
    print("PER-WORD STATS: word  corr tot_errs count_in_ref count_in_hyp", file=f)
    # Words with the most errors come first.
    for _, word, counts in sorted(
        [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
    ):
        (corr, ref_sub, hyp_sub, word_ins, word_dels) = counts
        word_errs = ref_sub + hyp_sub + word_ins + word_dels
        ref_count = corr + ref_sub + word_dels
        hyp_count = corr + hyp_sub + word_ins

        print(f"{word}   {corr} {word_errs} {ref_count} {hyp_count}", file=f)
    return float(tot_err_rate)