from typing import Optional, Tuple, List, Union
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from pado.tasks.asr.transducer_predictor import BaseASRTransducerPredictor
from pado.nn.modules import LSTM, Dropout, SeqDropout, LayerNorm, Linear

__all__ = ["LSTMTransducerPredictor"]


class LSTMTransducerPredictor(BaseASRTransducerPredictor):
    """LSTM-based predictor for a transducer ASR model.

    Token indices are embedded, scaled by ``sqrt(embed_dim)``, run through a
    multi-layer LSTM and then, optionally, through a LayerNorm and a linear
    projection.

    NOTE(review): ``self.word_emb``, ``self.embed_dim`` and ``self.blank_idx``
    are read in this class but never assigned here; they are presumably set by
    ``BaseASRTransducerPredictor.__init__`` -- confirm against the base class.
    """

    def __init__(self,
                 num_tokens: int,
                 embed_dim: int,
                 num_layers: int,
                 hidden_dim: int,
                 embed_drop_prob: float = 0.1,
                 rnn_drop_prob: float = 0.1,
                 word_drop_prob: float = 0.0,
                 out_norm: bool = True,
                 out_dim: Optional[int] = None,
                 eps: float = 1e-5,
                 variational_dropout: bool = False,
                 blank_idx: int = 0) -> None:
        """Build the embedding dropout, LSTM and optional output layers.

        :param num_tokens:          forwarded to the base class (presumably the vocabulary size)
        :param embed_dim:           embedding size; also the LSTM input size
        :param num_layers:          number of LSTM layers
        :param hidden_dim:          LSTM hidden size; also the LayerNorm size and projection input size
        :param embed_drop_prob:     dropout probability applied to the scaled embeddings
        :param rnn_drop_prob:       dropout probability applied to the LSTM output; also passed to
                                    the LSTM itself as its ``dropout`` argument
        :param word_drop_prob:      forwarded to the base class
        :param out_norm:            if True, apply LayerNorm to the LSTM output
        :param out_dim:             if given, project the output from ``hidden_dim`` to ``out_dim``
        :param eps:                 epsilon of the output LayerNorm
        :param variational_dropout: if True, use ``SeqDropout(..., dim=1)`` instead of ``Dropout``
                                    and forward the flag to the LSTM
        :param blank_idx:           forwarded to the base class; ``forward`` prepends
                                    ``self.blank_idx`` to every sequence
        """
        super().__init__(num_tokens, embed_dim, word_drop_prob=word_drop_prob, blank_idx=blank_idx)

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # Embeddings are initialized with std sqrt(1 / embed_dim) (see
        # _initialize_parameters); this factor is applied on top of them in forward/step.
        self.emb_scale = math.sqrt(embed_dim)

        if not variational_dropout:
            self.emb_drop = Dropout(embed_drop_prob)
            self.rnn_drop = Dropout(rnn_drop_prob)
        else:
            # NOTE(review): dim=1 is the time axis of the batch-first tensors used below;
            # presumably the mask is shared along it -- confirm in pado.nn.modules.
            self.emb_drop = SeqDropout(embed_drop_prob, dim=1)
            self.rnn_drop = SeqDropout(rnn_drop_prob, dim=1)

        self.rnn = LSTM(embed_dim, hidden_dim, num_layers, bias=True, batch_first=True,
                        dropout=rnn_drop_prob, variational_dropout=variational_dropout)

        if out_norm:
            self.out_norm = LayerNorm(hidden_dim, eps=eps)
        else:
            self.out_norm = None

        if out_dim is not None:
            self.out_proj = Linear(hidden_dim, out_dim)
        else:
            self.out_proj = None
        self._initialize_parameters()

    def _initialize_parameters(self) -> None:
        """Re-initialize the word embedding with a normal of std ``sqrt(1 / embed_dim)``."""
        nn.init.normal_(self.word_emb.weight.data, std=math.sqrt(1 / self.embed_dim))

    def forward(self,
                indices: torch.Tensor,
                lengths: torch.Tensor) -> torch.Tensor:
        """Run the predictor over whole (padded) label sequences.

        A blank token is prepended to every sequence, so the time axis of the
        result is one longer than the (longest) input sequence.

        :param indices:     (batch_size, seq_length)    long
        :param lengths:     (batch_size,)               valid length of each row of ``indices``
        :return:
                result:     (batch_size, max(lengths) + 1, hidden_dim or out_dim)
                            ``pad_packed_sequence`` is called without ``total_length``, so the
                            time axis follows the longest sequence in ``lengths``; it equals
                            ``seq_length + 1`` when the batch is padded to its longest sequence.
        :raises ValueError: if ``indices`` is not 2-dimensional
        """
        if indices.dim() != 2:
            raise ValueError(
                f"indices must be 2-D (batch_size, seq_length), got shape {tuple(indices.shape)}")

        # add blank at very front
        indices = F.pad(indices, (1, 0), mode="constant", value=self.blank_idx)  # (batch_size, max_seq + 1)
        # lengths may arrive on GPU (original note: moved there by DDP), while
        # pack_padded_sequence needs them on CPU; +1 accounts for the prepended blank.
        lengths = lengths.cpu().add(1)

        features = self.word_emb(indices)  # (batch_size, max_seq + 1, embed_dim)
        features = features * self.emb_scale
        features = self.emb_drop(features)

        features = nn.utils.rnn.pack_padded_sequence(features, lengths, batch_first=True,
                                                     enforce_sorted=False)  # preserve order
        features, _ = self.rnn(features)
        features, _ = nn.utils.rnn.pad_packed_sequence(features, batch_first=True)
        features = self.rnn_drop(features)  # (batch_size, max(lengths) + 1, hidden_dim)

        if self.out_norm is not None:
            features = self.out_norm(features)
        if self.out_proj is not None:
            features = self.out_proj(features)
        return features

    def step(self,
             index: torch.Tensor,
             hiddens: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
             ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Run a single step (sequence_length == 1)

        Unlike ``forward``, no blank is prepended and neither ``emb_drop`` nor
        ``rnn_drop`` is applied.

        :param index:       (batch_size, 1)
        :param hiddens:     [(num_layers, batch_size, hidden_dim), (num_layers, batch_size, hidden_dim)]
                            or None to let the LSTM use its default initial state
        :return:
                result:     (batch_size, 1, ...)
                new_hiddens
        """
        features = self.word_emb(index)  # (batch_size, 1, embed_dim)
        features = features * self.emb_scale

        # no sequence packing
        features, hiddens = self.rnn(features, hiddens)
        if self.out_norm is not None:
            features = self.out_norm(features)
        if self.out_proj is not None:
            features = self.out_proj(features)

        return features, hiddens

    def extract_hiddens(self,
                        hiddens: Optional[Tuple[torch.Tensor, torch.Tensor]],
                        batch_indices: Optional[Union[List[int], torch.Tensor]] = None,
                        ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Select the LSTM states of some batch entries.

        :param hiddens:         (hx, cx) pair, batch on dim 1, or None
        :param batch_indices:   entries to select along dim 1; None selects everything
        :return:
                None if ``hiddens`` is None; ``hiddens`` itself (same objects, no copy) if
                ``batch_indices`` is None; otherwise a new contiguous (hx, cx) pair holding
                only the selected entries
        """
        if hiddens is None:
            return None
        if batch_indices is None:
            return hiddens

        hx, cx = hiddens
        new_hx = hx[:, batch_indices].contiguous()
        new_cx = cx[:, batch_indices].contiguous()
        return new_hx, new_cx

    def update_hiddens(self,
                       hiddens: Tuple[torch.Tensor, torch.Tensor],
                       new_hiddens: Optional[Tuple[torch.Tensor, torch.Tensor]],
                       batch_indices: Optional[Union[List[int], torch.Tensor]] = None,
                       ) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """Overwrite, in place, the states of some batch entries with new values.

        NOTE(review): ``new_hiddens`` is indexed with the same ``batch_indices`` as
        ``hiddens``, so it is expected to share the batch layout of ``hiddens`` rather
        than the compacted layout returned by ``extract_hiddens`` -- confirm with callers.

        :param hiddens:         (hx, cx) pair to update; its tensors are modified in place
        :param new_hiddens:     (hx, cx) pair to copy from, or None to leave ``hiddens`` untouched
        :param batch_indices:   entries along dim 1 to copy; None copies every entry of ``hiddens``
        :return:
                the (mutated) ``hiddens`` pair
        """
        if new_hiddens is None:
            return hiddens
        if batch_indices is None:
            batch_size = hiddens[0].shape[1]
            batch_indices = list(range(batch_size))

        hiddens[0][:, batch_indices] = new_hiddens[0][:, batch_indices]
        hiddens[1][:, batch_indices] = new_hiddens[1][:, batch_indices]
        return hiddens
