from typing import Tuple, Optional, Union
import math

import torch
import torch.nn as nn
from omegaconf import DictConfig, OmegaConf

from pado.nn.modules import Linear, Dropout
from pado.nn.parameter import ParameterModule
from pado.nn.transformer import BidirectionalSinusoidalPositionalEncoding
from pado.nn.utils import make_mask_by_length, make_self_attn_mask_from_mask, apply_local_mask_to_attn_mask
from pado.tasks.asr.asr_encoder import BaseASREncoder
from conformer.conformer_layer import ConformerLayerWithV
from conformer.conformer_layer_reuse import ReuseAttnConformerLayerWithV

__all__ = ["ReuseAttnConformerEncoder"]


class ReuseAttnConformerEncoder(BaseASREncoder):
    """Conformer encoder in which selected layers reuse earlier attention scores.

    The encoder is a stack of ``num_layers`` layers driven by ``share_pattern``:

    * ``True``  -> a ``ConformerLayerWithV`` that computes its own attention scores.
    * ``False`` -> a ``ReuseAttnConformerLayerWithV`` that is fed the attention scores
      produced by the closest preceding ``True`` layer.

    The first ``num_layers // 2`` layers use ``num_heads[0]`` attention heads and the
    remaining layers use ``num_heads[1]``.
    """

    def __init__(self,
                 num_layers: int,
                 hidden_dim: int,
                 num_heads: Union[int, Tuple[int, int]],
                 conv_kernel_size: int,
                 feedforward_dim: Optional[int] = None,
                 embed_drop_prob: float = 0.0,
                 attn_drop_prob: float = 0.1,
                 proj_drop_prob: float = 0.1,
                 feedforward_drop_prob: Optional[float] = None,
                 eps: float = 1e-5,
                 momentum: float = 0.01, *,
                 attn_first: bool = True,
                 rel_attn: bool = True,
                 attn_bias: bool = False,
                 share_query_bias: bool = False,
                 conv_norm_type: str = "bn",
                 conv_sync_bn: bool = True,
                 conv_gn_groups: int = 2,
                 conv_partial: bool = False,
                 was: bool = False,
                 was_gamma: float = 0.5,
                 x_scaling: bool = True,
                 out_dim: Optional[int] = None,
                 left_context_length: int = -1,
                 right_context_length: int = -1,
                 pos_clamp_length: Optional[int] = None,
                 memory_efficient: bool = False,
                 share_v_proj_dim: Optional[int] = None,
                 share_pattern: Optional[Tuple[bool, ...]] = None) -> None:
        """
        :param num_layers:          number of encoder layers (must be >= 1)
        :param hidden_dim:          model width; also the expected last dim of the input features
        :param num_heads:           a single int, or a pair (heads for first half, heads for second half)
        :param feedforward_dim:     defaults to ``4 * hidden_dim`` when None
        :param x_scaling:           multiply the input by ``sqrt(hidden_dim)`` when True
        :param out_dim:             when given, a final ``Linear(hidden_dim, out_dim)`` is applied
        :param share_pattern:       per-layer flags (True: compute attention, False: reuse attention);
                                    defaults to all True. The first entry must be True.
        Remaining arguments are forwarded to the layer constructors unchanged.

        :raises ValueError: on non-positive ``num_layers``, a ``num_heads`` sequence whose length
                            is not 2, a ``share_pattern`` length mismatch, or a leading False.
        :raises TypeError:  if an entry of ``share_pattern`` is not a ``bool``.
        """
        super().__init__()

        if num_layers < 1:
            # Without this, an empty pattern would surface later as an opaque IndexError.
            raise ValueError(f"num_layers must be at least 1, got {num_layers}.")

        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        if isinstance(num_heads, int):
            num_heads = [num_heads, num_heads]
        if len(num_heads) != 2:
            raise ValueError(f"Num heads length {len(num_heads)} is not 2.")
        self.num_heads = num_heads
        self.use_diff_num_heads = (self.num_heads[0] != self.num_heads[1])

        if feedforward_dim is None:
            feedforward_dim = hidden_dim * 4
        self.feedforward_dim = feedforward_dim

        # ---------------------------------------------------------------- #
        # Positional encoding, embedding scale and embedding dropout
        # ---------------------------------------------------------------- #
        if rel_attn:  # only used for RelPos
            self.pos_emb = BidirectionalSinusoidalPositionalEncoding(hidden_dim, clamp_length=pos_clamp_length)
        else:  # do not use positional embedding. Often combined with attn_first=False
            self.pos_emb = None

        if x_scaling:
            self.emb_scale = math.sqrt(hidden_dim)  # known as XScaling (see TXL impl. or NeMo impl.)
        else:
            self.emb_scale = 1.0
        self.emb_drop = Dropout(embed_drop_prob, inplace=False)

        # ---------------------------------------------------------------- #
        # Body
        # ---------------------------------------------------------------- #
        # share pattern True -> create attention scores
        # share pattern False -> reuse attention scores
        if share_pattern is None:
            share_pattern = tuple([True] * num_layers)
        if len(share_pattern) != num_layers:
            raise ValueError(f"Sharing pattern length {len(share_pattern)} mismatch to num_layers {num_layers}")
        if not share_pattern[0]:
            # A reuse layer needs scores from an earlier layer, so the stack cannot start with one.
            raise ValueError("First share pattern should be True")

        self.share_pattern = share_pattern
        layers = []
        for layer_idx, sp in enumerate(share_pattern):
            if not isinstance(sp, bool):
                # Explicit raise instead of assert: the check must survive `python -O`.
                raise TypeError(f"share_pattern[{layer_idx}] must be bool, got {type(sp).__name__}.")
            layer_num_heads = self._num_heads_for_layer(layer_idx)
            if sp:  # True: layer computes its own attention scores
                layers.append(
                    ConformerLayerWithV(
                        hidden_dim, layer_num_heads, conv_kernel_size, feedforward_dim,
                        attn_drop_prob, proj_drop_prob, feedforward_drop_prob=feedforward_drop_prob,
                        eps=eps, momentum=momentum, attn_first=attn_first, rel_attn=rel_attn, attn_bias=attn_bias,
                        share_query_bias=share_query_bias, conv_norm_type=conv_norm_type,
                        conv_sync_bn=conv_sync_bn, conv_gn_groups=conv_gn_groups, conv_partial=conv_partial,
                        was=was, was_gamma=was_gamma, memory_efficient=memory_efficient)
                )
            else:  # False: layer reuses previously computed attention scores
                layers.append(
                    ReuseAttnConformerLayerWithV(
                        hidden_dim, layer_num_heads, conv_kernel_size, feedforward_dim,
                        proj_drop_prob, feedforward_drop_prob=feedforward_drop_prob,
                        eps=eps, momentum=momentum, attn_first=attn_first, conv_norm_type=conv_norm_type,
                        conv_sync_bn=conv_sync_bn, conv_gn_groups=conv_gn_groups, conv_partial=conv_partial,
                        share_v_proj_dim=share_v_proj_dim, memory_efficient=memory_efficient)
                )

        self.layers = nn.ModuleList(layers)

        self.share_query_bias = share_query_bias
        if share_query_bias and rel_attn:
            # Biases owned by the encoder and handed to every attention-computing layer in forward().
            self.query_key_bias = ParameterModule(torch.zeros(hidden_dim, dtype=torch.float32))
            self.query_pos_bias = ParameterModule(torch.zeros(hidden_dim, dtype=torch.float32))
        else:
            self.query_key_bias = self.query_pos_bias = None

        if out_dim is not None:
            self.out_proj = Linear(hidden_dim, out_dim)
        else:
            self.out_proj = None

        self.memory_efficient = memory_efficient
        self.left_context_length = left_context_length
        self.right_context_length = right_context_length

        self._initialize_parameters()

    def _num_heads_for_layer(self, layer_idx: int) -> int:
        """Return the head count of layer ``layer_idx``: first half uses ``num_heads[0]``, the rest ``num_heads[1]``."""
        if layer_idx < self.num_layers // 2:
            return self.num_heads[0]
        return self.num_heads[1]

    @torch.no_grad()
    def _initialize_parameters(self):
        """Rescale the already-initialized ``Linear`` weights inside ``self.layers``.

        Selection is by module name:
        * name contains "attn" and "v_proj"      -> weight scaled by ``s_v``
        * name contains "attn" (anything else)   -> left untouched
        * name contains neither "attn" nor "conv" -> weight scaled by ``s_ff``
        """
        # inspired by T-FixUp encoder
        # https://github.com/layer6ai-labs/T-Fixup/blob/master/fairseq/modules/transformer_layer.py
        # conformer conv-dw is already initialized
        s_ff = 0.67 * (self.num_layers ** -0.25)
        s_v = s_ff * math.sqrt(2.0)
        for module_name, module in self.layers.named_modules():
            if not isinstance(module, Linear):
                continue
            if "attn" in module_name:
                if "v_proj" in module_name:  # v
                    module.weight.data.mul_(s_v)
            elif "conv" not in module_name:  # proj, ff
                module.weight.data.mul_(s_ff)

    def forward(self,
                features: torch.Tensor,
                lengths: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        :param features:        (batch_size, max_seq_length, feature_dim), feature_dim must equal hidden_dim
        :param lengths:         (batch_size,)
        :return:
                output:         (batch_size, max_seq_length, hidden_dim), passed through ``out_proj``
                                when ``out_dim`` was given
                out_lengths:    (batch_size,), the input ``lengths`` returned unchanged
                scores:         (batch_size, num_layers, num_heads, max_seq_length, max_seq_length), detached.
                                When the two halves use different head counts this is a pair
                                (first-half stack, second-half stack) instead of a single tensor.
                values:         (batch_size, num_layers, num_heads, max_seq_length, head_dim), detached.
                                Same pair convention as ``scores``.
                hiddens:        (batch_size, num_layers + 1, max_seq_length, hidden_dim), detached;
                                entry 0 is the scaled (and dropped-out) input.
        :raises ValueError: on a feature-dim mismatch, CPU input with ``memory_efficient`` enabled,
                            a length larger than ``max_seq_length``, or a non-positive length.
        """
        # The 3-way unpack also rejects inputs that are not rank 3.
        _, seq_len, hidden_dim = features.shape
        if hidden_dim != self.hidden_dim:
            raise ValueError(
                f"ReuseAttnConformerEncoder expected feature dim {self.hidden_dim}, got {hidden_dim}.")

        if self.memory_efficient and (features.device == torch.device("cpu")):
            raise ValueError("ReuseAttnConformerEncoder memory_efficient ON, but input is on CPU.")

        # ---------------------------------------------------------------- #
        if self.emb_scale != 1:
            features = features * self.emb_scale
        features = self.emb_drop(features)
        if torch.max(lengths) > seq_len:
            raise ValueError(
                f"ReuseAttnConformerEncoder got a length larger than the sequence length {seq_len}.")
        if torch.any(torch.less_equal(lengths, 0)):
            raise ValueError("ReuseAttnConformerEncoder input length too short, consider padding.")

        if self.pos_emb is not None:
            pos_emb = self.pos_emb(seq_len)  # (b, 2s - 1, d)
            pos_emb = self.emb_drop(pos_emb)
        else:
            pos_emb = None

        # ---------------------------------------------------------------- #
        mask = make_mask_by_length(lengths, max_length=seq_len).to(features.device)  # (b, s)
        attn_mask = make_self_attn_mask_from_mask(mask).to(features.device)  # (b, s, s)
        attn_mask = apply_local_mask_to_attn_mask(attn_mask, self.left_context_length, self.right_context_length)

        # ---------------------------------------------------------------- #
        h = features
        hiddens = [h.detach()]  # per-layer hidden states, starting with the encoder input

        scores = []
        values = []
        reused_scores = None  # most recent non-detached scores from an attention-computing layer
        for layer, sp in zip(self.layers, self.share_pattern):
            if sp:
                h, attn_scores, v = layer(h,
                                          pos_emb,
                                          self.query_key_bias() if (self.query_key_bias is not None) else None,
                                          self.query_pos_bias() if (self.query_pos_bias is not None) else None,
                                          attn_mask=attn_mask,
                                          conv_mask=mask)
                # Keep the live (non-detached) scores for the following reuse layers.
                reused_scores = attn_scores[0]
            else:
                # Internal invariant: __init__ guarantees the first layer computes attention.
                assert reused_scores is not None
                h, attn_scores, v = layer(h,
                                          reused_scores,
                                          conv_mask=mask)
            scores.append(attn_scores[0].detach())
            values.append(v[0].detach())
            hiddens.append(h.detach())

        half = self.num_layers // 2
        if not self.use_diff_num_heads:
            scores = torch.stack(scores, dim=1)  # (batch_size, num_layers, num_heads, seq_len, seq_len)
        else:
            # Different head counts cannot share one tensor, so each half is stacked separately.
            scores = (torch.stack(scores[:half], dim=1), torch.stack(scores[half:], dim=1))
        hiddens = torch.stack(hiddens, dim=1)  # (batch_size, num_layers + 1, seq_len, hidden_dim)

        if not all(self.share_pattern):  # some layers reuse attention
            # Values whose last dim is not `hidden_dim // num_heads * 2` are zero-padded to twice
            # their own width so they can be stacked with the others.
            # NOTE(review): presumably the two layer types emit values of different widths - confirm.
            expanded_values = []
            for layer_idx, v in enumerate(values):
                value_dim = self.hidden_dim // self._num_heads_for_layer(layer_idx) * 2
                if v.shape[-1] != value_dim:
                    v = torch.cat([v, torch.zeros_like(v)], dim=-1)
                expanded_values.append(v)
            values = expanded_values

        if not self.use_diff_num_heads:
            values = torch.stack(values, dim=1)  # (batch_size, num_layers, num_heads, seq_len, head_dim)
        else:
            values = (torch.stack(values[:half], dim=1), torch.stack(values[half:], dim=1))

        output = h  # (b, s, d)
        if self.out_proj is not None:
            output = self.out_proj(output)
        return output, lengths, scores, values, hiddens

    @classmethod
    def from_config(cls, cfg: DictConfig) -> "ReuseAttnConformerEncoder":
        """Build an encoder from an OmegaConf config whose resolved keys are the ``__init__`` arguments."""
        cfg = OmegaConf.to_container(cfg, resolve=True)
        return cls(**cfg)
