from typing import Tuple, Optional
import math
import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf

from pado.nn.modules import Linear, Dropout
from pado.nn.parameter import ParameterModule
from pado.nn.transformer import BidirectionalSinusoidalPositionalEncoding
from pado.nn.utils import make_mask_by_length, make_self_attn_mask_from_mask, apply_local_mask_to_attn_mask

from pado.tasks.asr.asr_encoder import BaseASREncoder
from pado.models.conformer.conformer_layer import ConformerLayer

# Public API of this module: only the encoder class is exported.
__all__ = ["ConformerEncoder"]


class ConformerEncoder(BaseASREncoder):
    """Stack of ``ConformerLayer`` modules used as an ASR encoder.

    The forward pass scales the input features (optional), applies embedding
    dropout, builds an (optional) relative positional encoding and the padding /
    local-context attention masks, runs every layer in order and finally applies
    an optional output projection.
    """

    def __init__(self,
                 num_layers: int,
                 hidden_dim: int,
                 num_heads: int,
                 conv_kernel_size: int,
                 feedforward_dim: Optional[int] = None,
                 embed_drop_prob: float = 0.0,
                 attn_drop_prob: float = 0.1,
                 proj_drop_prob: float = 0.1,
                 feedforward_drop_prob: Optional[float] = None,
                 eps: float = 1e-5,
                 momentum: float = 0.01,
                 attn_first: bool = True,
                 rel_attn: bool = True,
                 attn_bias: bool = False,
                 share_query_bias: bool = False,
                 conv_norm_type: str = "bn",
                 conv_sync_bn: bool = True,
                 conv_gn_groups: int = 2,
                 conv_partial: bool = False,
                 was: bool = False,
                 was_gamma: float = 0.5,
                 x_scaling: bool = True,
                 out_dim: Optional[int] = None,
                 left_context_length: int = -1,
                 right_context_length: int = -1,
                 pos_clamp_length: Optional[int] = None,
                 memory_efficient: bool = False) -> None:
        """
        :param num_layers:              number of ConformerLayer modules to stack.
        :param hidden_dim:              feature dimension expected at the input and used by every layer.
        :param num_heads:               forwarded to each ConformerLayer.
        :param conv_kernel_size:        forwarded to each ConformerLayer.
        :param feedforward_dim:         forwarded to each ConformerLayer; defaults to ``4 * hidden_dim`` when None.
        :param embed_drop_prob:         dropout probability applied to the (scaled) input features
                                        and to the positional encoding.
        :param attn_drop_prob:          forwarded to each ConformerLayer.
        :param proj_drop_prob:          forwarded to each ConformerLayer.
        :param feedforward_drop_prob:   forwarded to each ConformerLayer.
        :param eps:                     forwarded to each ConformerLayer.
        :param momentum:                forwarded to each ConformerLayer.
        :param attn_first:              forwarded to each ConformerLayer.
        :param rel_attn:                if True a BidirectionalSinusoidalPositionalEncoding is created and its
                                        output is passed to every layer; also forwarded to each ConformerLayer.
        :param attn_bias:               forwarded to each ConformerLayer.
        :param share_query_bias:        if True (and ``rel_attn`` is True) two zero-initialized bias parameters
                                        of size ``hidden_dim`` are owned by the encoder and handed to every layer;
                                        also forwarded to each ConformerLayer.
        :param conv_norm_type:          forwarded to each ConformerLayer.
        :param conv_sync_bn:            forwarded to each ConformerLayer.
        :param conv_gn_groups:          forwarded to each ConformerLayer.
        :param conv_partial:            forwarded to each ConformerLayer.
        :param was:                     forwarded to each ConformerLayer.
        :param was_gamma:               forwarded to each ConformerLayer.
        :param x_scaling:               if True the input is multiplied by ``sqrt(hidden_dim)``.
        :param out_dim:                 if given, a final ``Linear(hidden_dim, out_dim)`` is applied to the output.
        :param left_context_length:     passed to ``apply_local_mask_to_attn_mask``
                                        (presumably -1 means unlimited -- confirm against that helper).
        :param right_context_length:    passed to ``apply_local_mask_to_attn_mask``
                                        (presumably -1 means unlimited -- confirm against that helper).
        :param pos_clamp_length:        passed as ``clamp_length`` to the positional encoding.
        :param memory_efficient:        forwarded to each ConformerLayer; when True, ``forward`` rejects CPU inputs.
        """
        super().__init__()

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        if feedforward_dim is None:
            feedforward_dim = hidden_dim * 4
        self.feedforward_dim = feedforward_dim

        # ---------------------------------------------------------------- #
        # Positional encoding and else
        # ---------------------------------------------------------------- #
        if rel_attn:  # only used for RelPos
            self.pos_emb = BidirectionalSinusoidalPositionalEncoding(hidden_dim, clamp_length=pos_clamp_length)
        else:  # do not use positional embedding. Often combined with attn_first=False
            self.pos_emb = None

        # known as XScaling (see TXL impl. or NeMo impl.)
        self.emb_scale = math.sqrt(hidden_dim) if x_scaling else 1.0
        self.emb_drop = Dropout(embed_drop_prob, inplace=False)

        # ---------------------------------------------------------------- #
        # Body
        # ---------------------------------------------------------------- #
        self.layers = nn.ModuleList([
            ConformerLayer(hidden_dim, num_heads, conv_kernel_size, feedforward_dim,
                           attn_drop_prob, proj_drop_prob, feedforward_drop_prob=feedforward_drop_prob,
                           eps=eps, momentum=momentum, attn_first=attn_first, rel_attn=rel_attn, attn_bias=attn_bias,
                           share_query_bias=share_query_bias, conv_norm_type=conv_norm_type,
                           conv_sync_bn=conv_sync_bn, conv_gn_groups=conv_gn_groups, conv_partial=conv_partial,
                           was=was, was_gamma=was_gamma, memory_efficient=memory_efficient)
            for _ in range(self.num_layers)
        ])

        # Query biases shared by all layers; they only exist for relative attention.
        self.share_query_bias = share_query_bias
        if share_query_bias and rel_attn:
            self.query_key_bias = ParameterModule(torch.zeros(hidden_dim, dtype=torch.float32))
            self.query_pos_bias = ParameterModule(torch.zeros(hidden_dim, dtype=torch.float32))
        else:
            self.query_key_bias = None
            self.query_pos_bias = None

        if out_dim is not None:
            self.out_proj = Linear(hidden_dim, out_dim)
        else:
            self.out_proj = None

        self.memory_efficient = memory_efficient
        self.left_context_length = left_context_length
        self.right_context_length = right_context_length

        self._initialize_parameters()

    @torch.no_grad()
    def _initialize_parameters(self):
        """Rescale, in place, the weights of selected ``Linear`` modules inside ``self.layers``.

        Selection is done on the module's qualified name within ``self.layers``:
        names containing "attn" are scaled only if they also contain "v_proj";
        names containing neither "attn" nor "conv" are scaled by the feed-forward factor.
        Everything else (other attention Linears, anything under "conv") is left as is.
        Modules outside ``self.layers`` (e.g. the encoder-level ``out_proj``) are not visited.
        """
        # inspired by T-FixUp encoder
        # https://github.com/layer6ai-labs/T-Fixup/blob/master/fairseq/modules/transformer_layer.py
        # conformer conv-dw is already initialized
        s_ff = 0.67 * (self.num_layers ** -0.25)
        s_v = s_ff * math.sqrt(2.0)
        for module_name, module in self.layers.named_modules():
            if not isinstance(module, Linear):
                continue
            if "attn" in module_name:
                if "v_proj" in module_name:  # v
                    module.weight.data.mul_(s_v)
            elif "conv" not in module_name:  # proj, ff, out_proj
                module.weight.data.mul_(s_ff)

    def forward(self,
                features: torch.Tensor,
                lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param features:        (batch_size, max_seq_length, feature_dim), feature_dim must equal hidden_dim
        :param lengths:         (batch_size,), every entry must be in [1, max_seq_length]
        :return:
                output:         (batch_size, max_seq_length, out_dim) if an output projection is configured,
                                otherwise (batch_size, max_seq_length, hidden_dim)
                out_lengths:    (batch_size,), the input ``lengths`` returned unchanged
        :raises ValueError:     on feature-dimension mismatch, CPU input with memory_efficient enabled,
                                or lengths that are non-positive or larger than max_seq_length
        """
        _, seq_len, feature_dim = features.shape

        # ---------------------------------------------------------------- #
        # Input validation. Explicit raises are used instead of `assert`
        # so that the checks are still active under `python -O`.
        # ---------------------------------------------------------------- #
        if feature_dim != self.hidden_dim:
            raise ValueError(f"ConformerEncoder expects feature dim {self.hidden_dim}, got {feature_dim}.")

        if self.memory_efficient and (features.device == torch.device("cpu")):
            raise ValueError("ConformerEncoder memory_efficient ON, but input is on CPU.")

        max_length = torch.max(lengths)
        if not (max_length <= seq_len):
            raise ValueError(f"ConformerEncoder lengths exceed sequence length: "
                             f"max length {max_length.item()} > {seq_len}.")
        if torch.any(torch.less_equal(lengths, 0)):
            raise ValueError("ConformerEncoder input length too short, consider padding.")

        # ---------------------------------------------------------------- #
        if self.emb_scale != 1:
            features = features * self.emb_scale
        features = self.emb_drop(features)

        if self.pos_emb is not None:
            pos_emb = self.pos_emb(seq_len)  # presumably (b, 2s - 1, d) per the original note -- confirm
            pos_emb = self.emb_drop(pos_emb)
        else:
            pos_emb = None

        # ---------------------------------------------------------------- #
        mask = make_mask_by_length(lengths, max_length=seq_len).to(features.device)  # (b, s)
        attn_mask = make_self_attn_mask_from_mask(mask).to(features.device)  # (b, s, s)
        attn_mask = apply_local_mask_to_attn_mask(attn_mask, self.left_context_length, self.right_context_length)

        # ---------------------------------------------------------------- #
        h = features
        for layer in self.layers:
            # The second value returned by the layer is not used by the encoder.
            h, _ = layer(h,
                         pos_emb,
                         self.query_key_bias() if (self.query_key_bias is not None) else None,
                         self.query_pos_bias() if (self.query_pos_bias is not None) else None,
                         attn_mask=attn_mask,
                         conv_mask=mask)

        output = h  # (b, s, d)
        if self.out_proj is not None:
            output = self.out_proj(output)
        return output, lengths

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "ConformerEncoder":
        """Build an encoder from an OmegaConf config.

        The config is resolved and converted to a plain container, then unpacked
        as keyword arguments of ``__init__``; its keys must therefore match the
        constructor's parameter names.
        """
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
