from typing import Optional, Tuple
import math
import torch
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf

from pado.tasks.asr.transducer_predictor import BaseASRTransducerPredictor
from pado.tasks.asr.transducer_jointer import BaseASRTransducerJointer
from pado.tasks.asr.utils import BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE

__all__ = ["ASRTransducerBeamsearch"]


class ASRTransducerBeamsearch(object):
    """Greedy / beam-search decoder for a transducer model (predictor + jointer).

    The predictor and jointer are not owned by this object: they must be attached to
    ``self.predictor`` and ``self.jointer`` before the decoder is called.
    """

    def __init__(self,
                 blank_idx: int = 0,
                 score_normalize: bool = True,
                 temperature: float = 1.0,
                 max_tokens_per_step: int = 20) -> None:
        """
        :param blank_idx:               vocabulary index of the blank symbol.
        :param score_normalize:         if True, the final beam hypothesis is selected by
                                        logp_score / len(prediction) instead of raw logp_score.
        :param temperature:             jointer outputs are divided by this before log_softmax.
        :param max_tokens_per_step:     greedy decoding runs at most this many
                                        predictor/jointer steps per encoder frame.
        """
        self.blank_idx = blank_idx
        self.score_normalize = score_normalize  # normalize score when sorting finals
        self.temperature = temperature
        self.max_tokens_per_step = max_tokens_per_step

        # NOTE: state_beam / expand_beam pruning (as in the SpeechBrain reference) is not
        # implemented. If added, a natural default is 2.3:
        #   exp(-2.3) ~= 0.1
        #   Y >= X * 0.1
        #   logY >= logX + log(0.1) = logX - 2.3

        # Because this class is NOT an nn.Module, these are not registered as submodules,
        # so their parameters are not doubly included (safe).
        self.predictor: Optional[BaseASRTransducerPredictor] = None
        self.jointer: Optional[BaseASRTransducerJointer] = None

    @torch.no_grad()
    def __call__(self,
                 features: torch.Tensor,
                 lengths: torch.Tensor,
                 beam_width: int) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """Decode a batch of encoder outputs.

        :param features:        (batch_size, length, dim)
        :param lengths:         (batch_size,) number of valid frames per sample.
        :param beam_width:      <= 1 selects greedy decoding, otherwise beam search.
        :return:                (predictions, logp_scores), one entry per sample; each
                                prediction is a list of token indices without blanks.
        :raises ValueError:     if predictor or jointer has not been attached.
        """
        if (self.predictor is None) or (self.jointer is None):
            raise ValueError("TransducerBeamsearch predictor or jointer is None.")

        # Remember each module's own mode so both are restored exactly, even on error.
        predictor_was_training = self.predictor.training
        jointer_was_training = self.jointer.training
        self.predictor.eval()
        self.jointer.eval()
        try:
            if beam_width <= 1:
                return self._greedy_decode(features, lengths)
            return self._beamsearch_decode(features, lengths, beam_width)
        finally:
            if predictor_was_training:
                self.predictor.train()
            if jointer_was_training:
                self.jointer.train()

    def _greedy_decode(self,
                       features: torch.Tensor,
                       lengths: torch.Tensor) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """Batched greedy (arg-max) transducer decoding.

        For each frame, repeatedly run predictor + jointer and take the arg-max token until
        every sample has emitted blank (or is past its length), or ``max_tokens_per_step``
        steps have run.

        :param features:        (batch_size, length, dim)
        :param lengths:         (batch_size,)
        :return:                per-sample token lists (blank excluded) and the summed
                                log-probability of the emitted non-blank tokens only.
        """
        # ---------------------------------------------------------------- #
        # https://github.com/NVIDIA/NeMo/blob/master/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
        # GreedyBatchedRNNTInfer._greedy_decode_blank_as_pad
        # ---------------------------------------------------------------- #
        device = features.device
        batch_size = features.shape[0]
        assert lengths.shape[0] == batch_size
        lengths = lengths.to(device)  # combined with on-device masks below
        hiddens = None

        prediction = [[] for _ in range(batch_size)]
        logp_scores = [0.0 for _ in range(batch_size)]

        last_token = torch.full((batch_size, 1), fill_value=self.blank_idx, dtype=torch.long, device=device)

        max_length = lengths.max().item()
        for idx in range(max_length):
            enc = features[:, idx].unsqueeze(1)  # (b, 1, d)

            # Samples whose length is exhausted (length <= idx) are masked from the start.
            blank_flag = torch.less_equal(lengths, idx)

            for _ in range(self.max_tokens_per_step):
                pred, hiddens_candidate = self.predictor.step(last_token, hiddens)  # for 1st step, all blank

                out = self.jointer.step(enc, pred).squeeze(1)  # (b, 1, V) -> (b, V)
                out = F.log_softmax(out.float().div_(self.temperature), dim=-1)

                logp, token = torch.max(out, dim=-1)  # (b,), (b,)
                # Accumulate the mask: once a sample emits blank in this frame it is frozen
                # (no prediction, score or hidden update) for the rest of the frame.
                blank_flag = torch.bitwise_or(blank_flag, torch.eq(token, self.blank_idx))

                if torch.all(blank_flag):
                    break  # every sample is done with this frame

                blank_indices = torch.nonzero(blank_flag, as_tuple=False)
                if hiddens is not None:
                    # roll back predictor state for masked samples
                    hiddens_candidate = self.predictor.update_hiddens(hiddens_candidate, hiddens,
                                                                      batch_indices=blank_indices)
                # NOTE(review): on the very first predictor call (hiddens is None) masked samples
                # keep the state obtained from consuming the start blank; NeMo resets it to zeros
                # instead. Fixing this needs predictor API support for a None/zero source.

                # masked samples keep feeding their previous token (as in NeMo)
                token[blank_indices] = last_token[blank_indices, 0]

                last_token = token.clone().view(-1, 1)
                hiddens = hiddens_candidate

                # one host transfer per step instead of one `.item()` per sample
                for i, (masked, t, p) in enumerate(zip(blank_flag.tolist(), token.tolist(), logp.tolist())):
                    if not masked:
                        prediction[i].append(t)
                        logp_scores[i] += p

        return prediction, logp_scores

    @staticmethod
    def _pop_best_hyp(hyps: list) -> dict:
        """Remove and return the first hypothesis with the highest ``logp_score``.

        Removal is by position, so hypotheses are never compared with ``==`` (which would
        compare predictor hidden tensors).
        """
        best_pos = max(range(len(hyps)), key=lambda i: hyps[i]["logp_score"])
        return hyps.pop(best_pos)

    def _beamsearch_decode(self,
                           features: torch.Tensor,
                           lengths: torch.Tensor,
                           beam_width: int) -> Tuple[BEAM_PREDICTION_DTYPE, BEAM_LOGP_SCORE_DTYPE]:
        """Per-sample transducer beam search.

        :param features:        (batch_size, length, dim)
        :param lengths:         (batch_size,)
        :param beam_width:      hypotheses kept per frame; <= 1 falls back to greedy decoding.
        :return:                per-sample best prediction (start blank removed) and its
                                accumulated log-probability (blank emissions included).
        """
        if beam_width <= 1:
            return self._greedy_decode(features, lengths)
        # ---------------------------------------------------------------- #
        # https://github.com/NVIDIA/NeMo/blob/master/nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py
        # BeamRNNTInfer.default_beam_search
        # https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/decoders/transducer.py
        # TransducerBeamSearcher.transducer_beam_search_decode
        # ---------------------------------------------------------------- #
        batch_size = features.shape[0]
        assert batch_size == lengths.shape[0]

        hyp_prediction = []
        hyp_logp_score = []
        device = features.device

        def _compute_score(x) -> float:
            # final beam scoring; the length includes the leading start blank
            denominator = max(len(x["prediction"]), 1) if self.score_normalize else 1
            return x["logp_score"] / denominator

        # -------------------------------------------------------------------------------- #
        # beam search implementation per-sample.
        for batch_idx in range(batch_size):
            enc_len = int(lengths[batch_idx].item())
            # Each hypothesis starts with blank as start-of-sequence token.
            # "hidden" is the predictor state *before* consuming prediction[-1].
            beam_hyps = [{
                "prediction": [self.blank_idx],
                "logp_score": 0.0,
                "hidden": None
            }]

            for enc_count in range(enc_len):
                enc_seq = features[batch_idx, enc_count].view(1, 1, -1)  # (1, 1, hidden_dim)

                process_hyps = beam_hyps
                beam_hyps = []
                while len(beam_hyps) < beam_width and process_hyps:
                    a_best_hyp = self._pop_best_hyp(process_hyps)

                    # forward (fresh input tensor: the predictor may keep a reference to it)
                    token_index = torch.full((1, 1), fill_value=a_best_hyp["prediction"][-1],
                                             dtype=torch.long, device=device)
                    dec_seq, hidden = self.predictor.step(token_index, a_best_hyp["hidden"])
                    out_seq = self.jointer.step(enc_seq, dec_seq).reshape(-1)  # (1, 1, V) -> (V,)
                    out_seq = F.log_softmax(out_seq.float().div_(self.temperature), dim=-1)  # (V,)

                    # blank: the prefix does not grow, so the old predictor state is kept
                    beam_hyps.append({
                        "prediction": a_best_hyp["prediction"].copy(),
                        "logp_score": a_best_hyp["logp_score"] + out_seq[self.blank_idx].item(),
                        "hidden": a_best_hyp["hidden"]
                    })

                    # non-blank expansions; never request more tokens than non-blank exist
                    out_seq[self.blank_idx] = -math.inf
                    num_expand = min(beam_width, out_seq.shape[0] - 1)
                    if num_expand <= 0:
                        continue
                    logp_targets, topk_tokens = torch.topk(out_seq, k=num_expand)

                    for logp_target, topk_token in zip(logp_targets.tolist(), topk_tokens.tolist()):
                        if topk_token == self.blank_idx:
                            continue  # only reachable if every non-blank logit is -inf
                        process_hyps.append({
                            "prediction": a_best_hyp["prediction"] + [topk_token],
                            "logp_score": a_best_hyp["logp_score"] + logp_target,
                            "hidden": hidden
                        })

            # -------------------------------------------------------------------------------- #
            # finalize: drop the leading start blank so output matches greedy decoding
            beam_best_hyp = max(beam_hyps, key=_compute_score)
            hyp_prediction.append(beam_best_hyp["prediction"][1:])
            hyp_logp_score.append(beam_best_hyp["logp_score"])

        return hyp_prediction, hyp_logp_score

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "ASRTransducerBeamsearch":
        """Build from a config whose keys are this class's constructor arguments."""
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
