import math
from collections import defaultdict
from typing import Tuple, List

import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf

# ctcdecode is an optional dependency: remember whether it could be imported so
# that decoding can fall back to the pure-Python beam search when it is missing.
try:
    from ctcdecode import CTCBeamDecoder

    CTC_DECODE_AVAILABLE = True
except ImportError:
    CTCBeamDecoder = None
    CTC_DECODE_AVAILABLE = False

from pado.tasks.asr.utils import BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE

# Names exported by `from <module> import *`.
__all__ = ["ASRCTCBeamsearch"]

# Log-domain value of probability zero; used as the initial score of beam entries.
_NEG_INF = -float("inf")


class ASRCTCBeamsearch(object):
    """Decode CTC outputs into token-id sequences, greedily or by beam search.

    Calling an instance dispatches on ``beam_width``: values ``<= 1`` use greedy
    (best-path) decoding, larger values use the ``ctcdecode`` library when it
    could be imported and the pure-Python prefix beam search otherwise.

    ``alpha``, ``beta``, ``cutoff_top_n`` and ``cutoff_prob`` are only forwarded
    to ``ctcdecode``'s ``CTCBeamDecoder``; the greedy and pure-Python decoders
    do not read them.
    """

    def __init__(self,
                 blank_idx: int = 0,
                 temperature: float = 1.0,
                 alpha: float = 0.0,
                 beta: float = 0.0,
                 cutoff_top_n: int = 40,
                 cutoff_prob: float = 1.0) -> None:
        """
        :param blank_idx:       vocabulary index of the CTC blank token
        :param temperature:     divisor applied to the features before log-softmax
        :param alpha:           forwarded to CTCBeamDecoder (library path only)
        :param beta:            forwarded to CTCBeamDecoder (library path only)
        :param cutoff_top_n:    forwarded to CTCBeamDecoder (library path only)
        :param cutoff_prob:     forwarded to CTCBeamDecoder (library path only)
        """
        self.blank_idx = blank_idx
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta
        self.cutoff_top_n = cutoff_top_n
        self.cutoff_prob = cutoff_prob

    @torch.no_grad()
    def __call__(self,
                 features: torch.Tensor,
                 lengths: torch.Tensor,
                 beam_width: int) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """
        CTC beam decoding with width=1 is not the same as greedy decoding (as expected).
        To force beamsearch for beam_width=1 (almost never needed), you should manually call _beamsearch_decode.
        :param features:        (batch_size, length, vocab_size)
        :param lengths:         (batch_size,)
        :param beam_width:      <= 1 selects greedy decoding, otherwise the beam size
        :return:                (prediction, logp_score): one token-id list and one
                                log-probability score per batch element
        """
        if beam_width <= 1:
            # greedy decoding and beam_width == 1 are slightly different, but we decided to use greedy decode.
            return self._greedy_decode(features, lengths)
        if CTC_DECODE_AVAILABLE:
            return self._beamsearch_decode_lib(features, lengths, beam_width)
        return self._beamsearch_decode(features, lengths, beam_width)

    def _scaled_log_probs(self, features: torch.Tensor) -> torch.Tensor:
        """Return float log-softmax of ``features / temperature`` over the last dim.

        The division is deliberately out-of-place: ``Tensor.float()`` hands back
        the very same tensor when the input is already float32, so an in-place
        ``div_`` would rescale the caller's tensor on every call.
        """
        return F.log_softmax(features.float() / self.temperature, dim=-1)

    @staticmethod
    def _logsumexp(*values: float) -> float:
        """Return ``log(sum(exp(v)))`` of plain floats, treating -inf as probability zero."""
        peak = max(values)
        if peak == _NEG_INF:
            # every term is log(0); also avoids the NaN from (-inf) - (-inf) below
            return _NEG_INF
        return peak + math.log(sum(math.exp(v - peak) for v in values))

    def _greedy_decode(self,
                       features: torch.Tensor,
                       lengths: torch.Tensor) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """
        Best-path decoding: take the arg-max token of every frame, merge consecutive
        repeats and drop blanks. The score is the sum of the per-frame maximum
        log-probabilities.
        :param features:        (batch_size, length, vocab_size)
        :param lengths:         (batch_size,) number of valid frames per sample
        :return:                (prediction, logp_scores), one entry per sample
        """
        # ---------------------------------------------------------------- #
        # https://github.com/NVIDIA/NeMo/blob/master/nemo/collections/asr/metrics/wer_bpe.py
        # WERBPE.ctc_decoder_predictions_tensor
        # ---------------------------------------------------------------- #
        batch_size, _, _ = features.shape
        assert lengths.shape[0] == batch_size

        log_probs = self._scaled_log_probs(features)

        prediction = []
        logp_scores = []

        # per-sample
        for batch_idx in range(batch_size):
            length = int(lengths[batch_idx].item())
            max_scores, max_tokens = torch.max(log_probs[batch_idx, :length], dim=-1)  # (s,), (s,)
            logp_scores.append(torch.sum(max_scores).item())

            collapsed = []
            previous = self.blank_idx
            for token in max_tokens.tolist():
                # emit a token only when it is not blank and not a repeat of the previous frame
                if token != previous and token != self.blank_idx:
                    collapsed.append(token)
                previous = token
            prediction.append(collapsed)

        return prediction, logp_scores

    def _beamsearch_decode_lib(self,
                               features: torch.Tensor,
                               lengths: torch.Tensor,
                               beam_width: int) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """
        Beam search through ctcdecode's CTCBeamDecoder; only the top hypothesis of
        each sample is returned.
        :param features:        (batch_size, length, vocab_size)
        :param lengths:         (batch_size,)
        :param beam_width:      beam size handed to CTCBeamDecoder
        :return:                (prediction, logp_scores), one entry per sample
        """
        assert CTC_DECODE_AVAILABLE, "ctcdecode is not installed"

        batch_size, _, vocab_size = features.shape
        assert lengths.shape[0] == batch_size
        log_probs = self._scaled_log_probs(features)

        # labels are just stringified token ids; only the indices are consumed below
        decoder = CTCBeamDecoder(labels=[str(i) for i in range(vocab_size)],
                                 alpha=self.alpha, beta=self.beta,
                                 cutoff_top_n=self.cutoff_top_n, cutoff_prob=self.cutoff_prob,
                                 beam_width=beam_width, blank_id=self.blank_idx, log_probs_input=True)
        beam_results, beam_scores, _, out_lengths = decoder.decode(log_probs, seq_lens=lengths)

        prediction = []
        logp_scores = []

        for i in range(batch_size):
            prediction.append(beam_results[i][0][:out_lengths[i][0]].tolist())  # take topmost hypothesis
            # NOTE(review): the sign flip assumes ctcdecode reports negative
            # log-likelihoods (lower is better) - confirm against the library docs.
            logp_scores.append(float(-beam_scores[i][0]))

        return prediction, logp_scores

    def _beamsearch_decode(self,
                           features: torch.Tensor,
                           lengths: torch.Tensor,
                           beam_width: int) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """
        Pure-Python CTC prefix beam search (no language model, no cutoff).

        Every prefix carries two log-probabilities, (ending in blank, ending in
        non-blank); after each frame the ``beam_width`` prefixes with the highest
        combined probability are kept.
        :param features:        (batch_size, length, vocab_size)
        :param lengths:         (batch_size,) number of valid frames per sample
        :param beam_width:      number of prefixes kept after every frame
        :return:                (prediction, logp_scores): best prefix per sample
                                and its total log-probability
        """
        # ---------------------------------------------------------------- #
        # https://gist.github.com/awni/56369a90d03953e370f3964c826ed4b0
        # https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/decoders/ctc.py
        # ---------------------------------------------------------------- #
        batch_size, _, _ = features.shape
        assert lengths.shape[0] == batch_size

        log_probs = self._scaled_log_probs(features)
        logsumexp = self._logsumexp
        blank_idx = self.blank_idx

        prediction = []
        logp_scores = []

        # TODO this impl. has no cutoff, and can be improved by beam-batch processing.
        # per-sample
        for batch_idx in range(batch_size):
            length = int(lengths[batch_idx].item())
            # one bulk conversion to Python floats instead of an .item() call per (frame, token)
            frames = log_probs[batch_idx, :length].tolist()  # s lists of V floats

            # initial beam prob.
            beam = [(tuple(), (0.0, _NEG_INF))]  # prefix, (p_blank, p_non_blank)

            for frame in frames:
                next_beam = defaultdict(lambda: (_NEG_INF, _NEG_INF))

                for v, p in enumerate(frame):
                    for prefix, (p_b, p_nb) in beam:
                        if v == blank_idx:
                            # a blank keeps the prefix and moves all its mass to "ends in blank"
                            n_p_b, n_p_nb = next_beam[prefix]
                            next_beam[prefix] = (logsumexp(n_p_b, p_b + p, p_nb + p), n_p_nb)
                            continue

                        end_t = prefix[-1] if prefix else None
                        n_prefix = prefix + (v,)
                        n_p_b, n_p_nb = next_beam[n_prefix]
                        if v != end_t:
                            n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
                        else:
                            # a repeated token extends the prefix only across a blank
                            n_p_nb = logsumexp(n_p_nb, p_b + p)

                        # we add LM here
                        next_beam[n_prefix] = (n_p_b, n_p_nb)

                        if v == end_t:
                            # repeat without a separating blank collapses into the same prefix
                            n_p_b, n_p_nb = next_beam[prefix]
                            next_beam[prefix] = (n_p_b, logsumexp(n_p_nb, p_nb + p))

                ranked = sorted(next_beam.items(),
                                key=lambda item: logsumexp(*item[1]),
                                reverse=True)
                beam = ranked[:beam_width]

            best_prefix, best_probs = beam[0]
            prediction.append(list(best_prefix))
            logp_scores.append(logsumexp(*best_probs))

        return prediction, logp_scores

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "ASRCTCBeamsearch":
        """Build a decoder from an OmegaConf config whose resolved keys are the ``__init__`` keyword arguments."""
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
