from typing import Optional, Tuple, List, Union
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from pado.tasks.asr.transducer_predictor import BaseASRTransducerPredictor
from pado.nn.modules import (Dropout, SeqDropout, LayerNorm, Linear,
                             LayerNormLSTMCellOnce, LayerNormLSTMCellEach, scan_lstm_cell)

__all__ = ["LayerNormLSTMTransducerPredictor"]


class LayerNormLSTMTransducerPredictor(BaseASRTransducerPredictor):
    """Transducer prediction network built from stacked LayerNorm-LSTM cells.

    Pipeline: token embedding scaled by ``sqrt(embed_dim)`` -> embedding dropout
    (``forward`` only) -> ``num_layers`` LayerNorm-LSTM cells -> optional output
    LayerNorm -> optional linear projection to ``out_dim``.
    """

    def __init__(self,
                 num_tokens: int,
                 embed_dim: int,
                 num_layers: int,
                 hidden_dim: int,
                 embed_drop_prob: float = 0.1,
                 rnn_drop_prob: float = 0.1,
                 word_drop_prob: float = 0.0,
                 out_norm: bool = True,
                 out_dim: Optional[int] = None,
                 eps: float = 1e-5,
                 variational_dropout: bool = False,
                 impl_type: str = "once",  # once or each
                 blank_idx: int = 0):
        """
        :param num_tokens:          vocabulary size, forwarded to the base class.
        :param embed_dim:           embedding size, forwarded to the base class; input size of the first cell.
        :param num_layers:          number of stacked LSTM cells.
        :param hidden_dim:          hidden size of every cell (and of the output unless ``out_dim`` is set).
        :param embed_drop_prob:     dropout probability applied to the scaled embeddings in ``forward``.
        :param rnn_drop_prob:       passed as ``dropout`` to each LSTM cell.
        :param word_drop_prob:      forwarded to the base class (its handling lives there).
        :param out_norm:            if True, apply ``LayerNorm(hidden_dim)`` to the final features.
        :param out_dim:             if given, project the final features with ``Linear(hidden_dim, out_dim)``.
        :param eps:                 LayerNorm epsilon for the cells and the output norm.
        :param variational_dropout: use ``SeqDropout(dim=1)`` for the embeddings and forward the flag to the cells.
        :param impl_type:           "once" -> ``LayerNormLSTMCellOnce``, "each" -> ``LayerNormLSTMCellEach``.
        :param blank_idx:           forwarded to the base class; prepended to every sequence in ``forward``.
        :raises ValueError:         if ``impl_type`` is neither "once" nor "each".
        """
        # Validate before anything is built: previously any unknown value silently
        # fell back to the "each" implementation.
        cell_classes = {"once": LayerNormLSTMCellOnce, "each": LayerNormLSTMCellEach}
        if impl_type not in cell_classes:
            raise ValueError(f"impl_type must be one of {sorted(cell_classes)}, got {impl_type!r}")
        cell_cls = cell_classes[impl_type]

        super().__init__(num_tokens, embed_dim, word_drop_prob=word_drop_prob, blank_idx=blank_idx)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        self.emb_scale = math.sqrt(embed_dim)

        if not variational_dropout:
            self.emb_drop = Dropout(embed_drop_prob)
        else:
            self.emb_drop = SeqDropout(embed_drop_prob, dim=1)

        # (LayerNormLSTMCell - LayerNorm - Drop) x N-times
        # The first cell consumes embeddings; the rest consume the previous cell's hidden states.
        self.rnns = nn.ModuleList([
            cell_cls(embed_dim if layer == 0 else hidden_dim, hidden_dim, bias=False, dropout=rnn_drop_prob,
                     eps=eps, init_forget_bias=True, variational_dropout=variational_dropout)
            for layer in range(num_layers)
        ])

        self.out_norm = LayerNorm(hidden_dim, eps=eps) if out_norm else None
        self.out_proj = Linear(hidden_dim, out_dim) if out_dim is not None else None
        self._initialize_parameters()

    def _initialize_parameters(self):
        """Initialize the token embedding table with N(0, 1 / embed_dim).

        Combined with the ``sqrt(embed_dim)`` scale applied after the lookup, the
        scaled embeddings start with unit variance.
        """
        nn.init.normal_(self.word_emb.weight.data, std=math.sqrt(1 / self.embed_dim))

    def forward(self,
                indices: torch.Tensor,
                lengths: torch.Tensor) -> torch.Tensor:
        """Run the predictor over full label sequences.

        A blank token is prepended to every sequence, so the output has one more
        time step than the input and its first step is computed from the blank alone.

        :param indices:     (batch_size, seq_length)    long
        :param lengths:     (batch_size,)
        :return:
                result:     (batch_size, seq_length + 1, out_dim or hidden_dim)
        :raises ValueError: if ``indices`` is not 2-D.
        """
        if indices.dim() != 2:
            raise ValueError(f"indices must be 2-D (batch_size, seq_length), got shape {tuple(indices.shape)}")

        # add blank at very front
        indices = F.pad(indices, (1, 0), mode="constant", value=self.blank_idx)  # (batch_size, max_seq + 1)
        # Account for the prepended blank. `add` is out-of-place, so the caller's tensor is untouched.
        # (sequence length is moved to GPU by DDP)
        lengths = lengths.add(1)

        features = self.word_emb(indices) * self.emb_scale  # (batch_size, max_seq + 1, embed_dim)
        features = self.emb_drop(features)

        # For the Transducer forward pass the final states are not needed;
        # states are only tracked in `step`.
        for rnn in self.rnns:
            features, _ = scan_lstm_cell(rnn, features, lengths, state=None)

        if self.out_norm is not None:
            features = self.out_norm(features)
        if self.out_proj is not None:
            features = self.out_proj(features)
        return features

    def step(self,
             index: torch.Tensor,
             hiddens: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
             ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run a single step (sequence_length == 1).

        No blank is prepended here and embedding dropout is NOT applied, unlike in
        ``forward``.

        :param index:       (batch_size, 1)
        :param hiddens:     [(num_layers, batch_size, hidden_dim), (num_layers, batch_size, hidden_dim)]
                            or None to start every layer from its default state.
        :return:
                result:     (batch_size, 1, out_dim or hidden_dim)
                new_hiddens (h, c), each stacked over layers and detached from the autograd graph.
        """
        features = self.word_emb(index) * self.emb_scale  # (batch_size, 1, embed_dim)

        layer_h = []
        layer_c = []
        for layer, rnn in enumerate(self.rnns):
            state = None if hiddens is None else (hiddens[0][layer], hiddens[1][layer])  # (hx, cx) or None
            features, (hy, cy) = scan_lstm_cell(rnn, features, state=state)
            layer_h.append(hy)
            layer_c.append(cy)

        # Returned states carry no autograd history.
        new_states = (torch.stack(layer_h, dim=0).detach(),
                      torch.stack(layer_c, dim=0).detach())

        if self.out_norm is not None:
            features = self.out_norm(features)
        if self.out_proj is not None:
            features = self.out_proj(features)
        return features, new_states

    def extract_hiddens(self,
                        hiddens: Optional[Tuple[torch.Tensor, torch.Tensor]],
                        batch_indices: Optional[Union[List[int], torch.Tensor]] = None,
                        ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Select a subset of batch entries (dim 1) from the (h, c) states.

        :param hiddens:         (h, c) as returned by ``step``, or None.
        :param batch_indices:   entries to keep along dim 1; None keeps all (the input is returned as-is).
        :return:                contiguous (h, c) restricted to ``batch_indices``, or None if ``hiddens`` is None.
        """
        if hiddens is None:
            return None
        if batch_indices is None:
            return hiddens

        hx, cx = hiddens
        return hx[:, batch_indices].contiguous(), cx[:, batch_indices].contiguous()

    def update_hiddens(self,
                       hiddens: Tuple[torch.Tensor, torch.Tensor],
                       new_hiddens: Optional[Tuple[torch.Tensor, torch.Tensor]],
                       batch_indices: Optional[Union[List[int], torch.Tensor]] = None,
                       ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Overwrite selected batch entries (dim 1) of ``hiddens`` with those of ``new_hiddens``.

        The tensors in ``hiddens`` are modified IN PLACE and the same tuple is returned.
        ``hiddens`` must not be None when ``new_hiddens`` is given.

        :param hiddens:         (h, c) to update.
        :param new_hiddens:     (h, c) to copy from; if None, ``hiddens`` is returned unchanged.
        :param batch_indices:   entries to copy along dim 1; None copies every entry of ``hiddens``.
        :return:                ``hiddens``.
        """
        if new_hiddens is None:
            return hiddens
        if batch_indices is None:
            batch_indices = list(range(hiddens[0].shape[1]))

        hiddens[0][:, batch_indices] = new_hiddens[0][:, batch_indices]
        hiddens[1][:, batch_indices] = new_hiddens[1][:, batch_indices]
        return hiddens
