from typing import List, Optional, Tuple, Union
import random
import copy
import warnings
import torch
from torch import nn, Tensor
from .zipformer import (
    Zipformer2, 
    Zipformer2EncoderLayer, 
    DownsampledZipformer2Encoder, 
    FeedforwardModule,
    Zipformer2Encoder,
    SimpleDownsample,
    CompactRelPositionalEncoding
    )
from ..utils.scaling import (
    ActivationDropoutAndLinear,
    Balancer,
    BiasNorm,
    ChunkCausalDepthwiseConv1d,
    Dropout2,
    FloatLike,
    ScheduledFloat,
    Whiten,
    convert_num_channels,
    limit_param_value,
    penalize_abs_values_gt,
    softmax,
)

class ZipformerMoe(Zipformer2):
    """Zipformer2 encoder whose encoder layers may use mixture-of-experts feed-forwards.

    Every encoder stack is built from ``ZipformerMoeEncoderLayer2``. That layer replaces
    each of its three feed-forward modules with a ``MoEModel`` when the corresponding
    character ('1', '2', '3') occurs in ``moe_type``, and with a dense
    FeedforwardModule otherwise.

    Arguments shared with Zipformer2 have the same meaning. MoE-specific arguments:
        num_experts: number of routed experts per MoE module.
        top_k: number of routed experts each frame is sent to.
        granularity: each expert's hidden dim is the dense FFN hidden dim
            divided by this value.
        num_shared_experts: number of experts applied to every frame.
        moe_type: string tested for '1', '2' and '3' (e.g. "1,2,3" or "2").
            A value containing none of these -- including the default 'middle' --
            leaves every feed-forward module dense (a warning is emitted).
        normalize_moe_score: if True, the top-k routing weights of each frame are
            renormalized to sum to 1.
    """

    def __init__(
        self,
        output_downsampling_factor: int = 2,
        downsampling_factor: Tuple[int] = (2, 4),
        encoder_dim: Union[int, Tuple[int]] = 384,
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
        query_head_dim: Union[int, Tuple[int]] = 24,
        pos_head_dim: Union[int, Tuple[int]] = 4,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        pos_dim: int = 192,
        dropout: FloatLike = None,
        warmup_batches: float = 4000.0,
        causal: bool = False,
        chunk_size: Tuple[int] = (-1,),
        left_context_frames: Tuple[int] = (-1,),
        num_experts: int = 4,
        top_k: int = 2,
        granularity: int = 1,
        num_shared_experts: int = 0,
        moe_type: str = 'middle',
        normalize_moe_score: bool = False
    ) -> None:
        # NOTE(review): this runs Zipformer2's constructor with its own default
        # arguments; self.encoders and self.downsample_output are overwritten
        # below, so whatever the parent built for them is discarded.
        super(ZipformerMoe, self).__init__()
        if dropout is None:
            dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1))

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with the same length
            as downsampling_factor"""
            if isinstance(x, int):
                x = (x,)
            if len(x) == 1:
                x = x * len(downsampling_factor)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
            return x

        self.output_downsampling_factor = output_downsampling_factor  # int
        self.downsampling_factor = downsampling_factor  # tuple
        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(
            encoder_unmasked_dim
        )  # tuple
        num_encoder_layers = _to_tuple(num_encoder_layers)
        self.num_encoder_layers = num_encoder_layers
        self.query_head_dim = query_head_dim = _to_tuple(query_head_dim)
        self.value_head_dim = value_head_dim = _to_tuple(value_head_dim)
        pos_head_dim = _to_tuple(pos_head_dim)
        self.num_heads = num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
        self.cnn_module_kernel = cnn_module_kernel = _to_tuple(cnn_module_kernel)

        self.causal = causal
        self.chunk_size = chunk_size
        self.left_context_frames = left_context_frames

        for u, d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d

        # MoE configuration
        self.num_experts = num_experts
        self.top_k = top_k
        self.moe_type = moe_type
        self.granularity = granularity
        self.num_shared_experts = num_shared_experts
        self.normalize_moe_score = normalize_moe_score

        # ZipformerMoeEncoderLayer2 only checks moe_type for the characters
        # '1', '2' and '3'; anything else silently yields a dense model.
        if not any(layer_id in moe_type for layer_id in ('1', '2', '3')):
            warnings.warn(
                f"moe_type={moe_type!r} contains none of '1', '2', '3'; "
                "no feed-forward module will be replaced by a mixture of experts."
            )

        # Build one MoE encoder stack per downsampling factor.
        num_encoders = len(downsampling_factor)
        encoders = []
        for i in range(num_encoders):
            encoder_layer = ZipformerMoeEncoderLayer2(
                embed_dim=encoder_dim[i],
                pos_dim=pos_dim,
                num_heads=num_heads[i],
                query_head_dim=query_head_dim[i],
                pos_head_dim=pos_head_dim[i],
                value_head_dim=value_head_dim[i],
                feedforward_dim=feedforward_dim[i],
                dropout=dropout,
                cnn_module_kernel=cnn_module_kernel[i],
                causal=causal,
                num_experts=num_experts,
                top_k=top_k,
                num_shared_experts=num_shared_experts,
                granularity=granularity,
                moe_layers=moe_type,
                normalize_moe_score=normalize_moe_score
            )

            # For the segment of the warmup period, we let the Conv2dSubsampling
            # layer learn something.  Then we start to warm up the other encoders.
            encoder = Zipformer2MoeEncoder(
                encoder_layer,
                num_encoder_layers[i],
                pos_dim=pos_dim,
                dropout=dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
            )

            if downsampling_factor[i] != 1:
                encoder = DownsampledZipformer2MoeEncoder(
                    encoder,
                    dim=encoder_dim[i],
                    downsample=downsampling_factor[i],
                    dropout=dropout,
                    causal=causal,
                )

            encoders.append(encoder)
        self.encoders = nn.ModuleList(encoders)

        self.downsample_output = SimpleDownsample(
            max(encoder_dim),
            downsample=output_downsampling_factor,
            dropout=dropout,
            causal=causal,
        )

    def forward(
        self,
        x: Tensor,
        x_lens: Tensor,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tuple[Tensor, ...], Tuple[Optional[Tensor], ...]]:
        """
        Args:
          x:
            The input tensor. Its shape is (seq_len, batch_size, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
          src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
        Returns:
          Return a tuple containing 4 elements:
            - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
            - gate_logits_all: a tuple with one router-logit tensor per MoE
              module, in encoder-stack and layer order.
            - src_key_padding_mask_all: a tuple of the same length; entry j is
              the padding mask (sliced to that stack's frame rate, or None) in
              effect when gate_logits_all[j] was produced.
        """
        outputs = []
        gate_logits_all = ()
        src_key_padding_mask_all = ()
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            feature_masks = [1.0] * len(self.encoder_dim)
        else:
            feature_masks = self.get_feature_masks(x)

        chunk_size, left_context_chunks = self.get_chunk_info()

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # Not support exporting a model for simulating streaming decoding
            attn_mask = None
        else:
            attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

        for i, module in enumerate(self.encoders):
            ds = self.downsampling_factor[i]
            x = convert_num_channels(x, self.encoder_dim[i])

            src_key_padding_mask_layer = (
                None if src_key_padding_mask is None else src_key_padding_mask[..., ::ds]
            )

            x, gate_logits = module(
                x,
                chunk_size=chunk_size,
                feature_mask=feature_masks[i],
                src_key_padding_mask=src_key_padding_mask_layer,
                attn_mask=attn_mask,
            )
            outputs.append(x)
            gate_logits_all += gate_logits
            # One mask entry per gate-logit entry so the two tuples stay aligned.
            src_key_padding_mask_all += (src_key_padding_mask_layer,) * len(gate_logits)
        # if the last output has the largest dimension, x will be unchanged,
        # it will be the same as outputs[-1].  Otherwise it will be concatenated
        # from different pieces of 'outputs', taking each dimension from the
        # most recent output that has it present.
        x = self._get_full_dim_output(outputs)
        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            lengths = (x_lens + 1) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                lengths = (x_lens + 1) // 2

        return x, lengths, gate_logits_all, src_key_padding_mask_all
    
        
class ZipformerMoeEncoderLayer(Zipformer2EncoderLayer):
    """Zipformer2 encoder layer whose middle feed-forward module is a mixture of experts.

    The parent's ``feed_forward2`` is deleted and replaced by a linear gate plus
    ``num_experts`` FeedforwardModule experts. Each frame is processed by its
    ``top_k`` highest-scoring experts, weighted by their softmax gate scores
    (not renormalized over the top-k).

    NOTE(review): ``forward`` returns ``(src, gate_scores)`` with ``gate_scores`` a
    single tensor, whereas Zipformer2MoeEncoder concatenates each layer's second
    output onto a tuple (``gate_logits += layer_gate_logits``), so this layer cannot
    be stacked by Zipformer2MoeEncoder as-is. ZipformerMoe uses
    ZipformerMoeEncoderLayer2 instead.
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        value_head_dim: int,
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
        causal: bool = False,
        num_experts: int = 4,  # number of experts
        top_k: int = 2,  # experts used per frame
        **kwargs,
    ):
        super().__init__(
            embed_dim=embed_dim,
            pos_dim=pos_dim,
            num_heads=num_heads,
            query_head_dim=query_head_dim,
            pos_head_dim=pos_head_dim,
            value_head_dim=value_head_dim,
            feedforward_dim=feedforward_dim,
            dropout=dropout,
            cnn_module_kernel=cnn_module_kernel,
            causal=causal,
            **kwargs,
        )

        self.num_experts = num_experts
        self.top_k = top_k

        # Replace feed_forward2 with a gated set of experts.
        del self.feed_forward2
        self.gate = nn.Linear(embed_dim, num_experts)  # gating network
        self.experts = nn.ModuleList([
            FeedforwardModule(embed_dim, feedforward_dim, dropout)
            for _ in range(num_experts)
        ])

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Pass the input through the encoder layer.

        Args:
            src: input of shape (seq_len, batch_size, embed_dim).
            pos_emb: relative positional embedding passed to self_attn_weights.
            chunk_size: number of frames per chunk for the conv modules; -1 means no chunking.
            attn_mask: optional attention mask; True means masked position.
            src_key_padding_mask: optional (batch_size, seq_len) padding mask;
                True means masked position.

        Returns:
            (output, gate_scores): output has the same shape as ``src``; gate_scores
            are the softmax gate probabilities of shape (seq_len, batch_size, num_experts).
        """
        src_orig = src

        # Dropout rate for non-feedforward submodules
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            attention_skip_rate = 0.0
        else:
            attention_skip_rate = (
                float(self.attention_skip_rate) if self.training else 0.0
            )

        # Attention weights
        attn_weights = self.self_attn_weights(
            src, pos_emb=pos_emb, attn_mask=attn_mask, key_padding_mask=src_key_padding_mask
        )

        # Feed-forward Layer 1
        src = src + self.feed_forward1(src)

        # Dropout mask for self-attention
        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)

        selected_attn_weights = attn_weights[0:1]
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif self.training and random.random() < float(self.const_attention_rate):
            # Use constant attention weights to encourage averaging
            selected_attn_weights = selected_attn_weights[0:1]
            selected_attn_weights = (selected_attn_weights > 0.0).to(
                selected_attn_weights.dtype
            )
            selected_attn_weights = selected_attn_weights * (
                1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
            )

        # Non-linear attention
        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))
        src = src + (
            na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
        )

        # Self-attention 1
        self_attn = self.self_attn1(src, attn_weights)
        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        # Convolution Module 1
        src = src + self.sequence_dropout(
            self.conv_module1(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            float(self.conv_skip_rate) if self.training else 0.0,
        )

        # ---------------------------MOE-------------------------------------------
        # Gating: softmax over experts, then keep the top-k per frame.
        gate_scores = torch.softmax(self.gate(src), dim=-1)  # (seq_len, batch_size, num_experts)
        top_k_scores, top_k_indices = torch.topk(gate_scores, self.top_k, dim=-1)

        moe_output = torch.zeros_like(src)  # (seq_len, batch_size, embed_dim)
        for i, expert in enumerate(self.experts):
            # (seq_len, batch_size, top_k): which top-k slots selected expert i.
            slot_hits = top_k_indices == i
            token_mask = slot_hits.any(dim=-1)  # (seq_len, batch_size)
            if token_mask.any():
                # top-k indices are distinct per frame, so at most one slot
                # matches; summing over slots just extracts expert i's score.
                weight = (top_k_scores * slot_hits).sum(dim=-1)[token_mask].unsqueeze(-1)
                weighted = weight * expert(src[token_mask])
                # Gate scores and expert outputs may differ in dtype under mixed
                # precision; indexed assignment needs matching dtypes.
                moe_output[token_mask] += weighted.to(moe_output.dtype)

        # Apply dropout and residual connection
        src = src + self.sequence_dropout(
            self.balancer_ff2(moe_output),
            float(self.ff2_skip_rate) if self.training else 0.0
        )

        # Bypass in the middle of the layer
        src = self.bypass_mid(src_orig, src)

        # Self-attention 2
        self_attn = self.self_attn2(src, attn_weights)
        src = src + (
            self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask
        )

        # Convolution Module 2
        src = src + self.sequence_dropout(
            self.conv_module2(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            float(self.conv_skip_rate) if self.training else 0.0,
        )

        # Feed-forward Layer 3
        src = src + self.sequence_dropout(
            self.balancer_ff3(self.feed_forward3(src)),
            float(self.ff3_skip_rate) if self.training else 0.0,
        )

        # Final normalization and whitening
        src = self.balancer1(src)
        src = self.norm(src)
        src = self.bypass(src_orig, src)
        src = self.balancer2(src)
        src = self.whiten(src)

        return src, gate_scores
    
    
class ZipformerMoeEncoderLayer2(Zipformer2EncoderLayer):
    """Zipformer2 encoder layer whose feed-forward modules may be mixtures of experts.

    The parent's three feed-forward modules are deleted and replaced by ``moe_1``,
    ``moe_2`` and ``moe_3``, with hidden dims ``(feedforward_dim * 3) // 4``,
    ``feedforward_dim`` and ``(feedforward_dim * 5) // 4``. Each one is a
    ``MoEModel`` when the character '1', '2' or '3' respectively occurs in
    ``moe_layers``, and a dense FeedforwardModule otherwise.
    """

    def __init__(
        self,
        embed_dim: int,
        pos_dim: int,
        num_heads: int,
        query_head_dim: int,
        pos_head_dim: int,
        value_head_dim: int,
        feedforward_dim: int,
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
        causal: bool = False,
        num_experts: int = 4,  # number of routed experts
        top_k: int = 2,  # routed experts used per frame
        granularity: int = 1,
        num_shared_experts: int = 0,
        moe_layers: str = '1,2,3',
        normalize_moe_score: bool = False,
        **kwargs,
    ):
        super().__init__(
            embed_dim=embed_dim,
            pos_dim=pos_dim,
            num_heads=num_heads,
            query_head_dim=query_head_dim,
            pos_head_dim=pos_head_dim,
            value_head_dim=value_head_dim,
            feedforward_dim=feedforward_dim,
            dropout=dropout,
            cnn_module_kernel=cnn_module_kernel,
            causal=causal,
            **kwargs,
        )

        self.num_experts = num_experts
        self.top_k = top_k
        self.granularity = granularity
        self.num_shared_experts = num_shared_experts
        self.moe_layers = moe_layers

        def _make_ffn(layer_id: str, hidden_dim: int) -> nn.Module:
            """Build the MoE or dense replacement for feed-forward module `layer_id`."""
            if layer_id in moe_layers:
                return MoEModel(num_experts, embed_dim, top_k, hidden_dim, dropout, granularity,
                                num_shared_experts,
                                normalized_score=normalize_moe_score)
            return FeedforwardModule(embed_dim, hidden_dim, dropout)

        # Replace the 3 feed-forward modules (same construction order as before,
        # so parameter initialization and registration order are unchanged).
        del self.feed_forward1
        self.moe_1 = _make_ffn('1', (feedforward_dim * 3) // 4)
        del self.feed_forward2
        self.moe_2 = _make_ffn('2', feedforward_dim)
        del self.feed_forward3
        self.moe_3 = _make_ffn('3', (feedforward_dim * 5) // 4)

    def _apply_ffn(
        self,
        layer_id: str,
        module: nn.Module,
        src: Tensor,
        gate_logits: Tuple[Tensor, ...],
    ) -> Tuple[Tensor, Tuple[Tensor, ...]]:
        """Run moe_1/moe_2/moe_3 and append its router logits (if any) to `gate_logits`."""
        if layer_id not in self.moe_layers:
            return module(src), gate_logits  # dense FFN: no router
        output, logits = module(src)
        # MoEModel returns None logits when it has no routed experts; skip those
        # so downstream consumers only see real tensors.
        if logits is not None:
            gate_logits = gate_logits + (logits,)
        return output, gate_logits

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tuple[Tensor, ...]]:
        """
        Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position.  May be None.

        Returns:
            (output, gate_logits): output has the same shape as src; gate_logits is a
            tuple with the router logits of each MoE module of this layer, in the
            order moe_1, moe_2, moe_3 (dense modules contribute nothing).
        """
        src_orig = src
        gate_logits: Tuple[Tensor, ...] = ()
        # dropout rate for non-feedforward submodules
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            attention_skip_rate = 0.0
        else:
            attention_skip_rate = (
                float(self.attention_skip_rate) if self.training else 0.0
            )

        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        attn_weights = self.self_attn_weights(
            src,
            pos_emb=pos_emb,
            attn_mask=attn_mask,
            key_padding_mask=src_key_padding_mask,
        )

        moe_output1, gate_logits = self._apply_ffn('1', self.moe_1, src, gate_logits)
        src = src + moe_output1

        self_attn_dropout_mask = self.get_sequence_dropout_mask(
            src, attention_skip_rate
        )

        selected_attn_weights = attn_weights[0:1]
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
        elif self.training and random.random() < float(self.const_attention_rate):
            # Make attention weights constant.  The intention is to
            # encourage these modules to do something similar to an
            # averaging-over-time operation.
            # only need the mask, can just use the 1st one and expand later
            selected_attn_weights = selected_attn_weights[0:1]
            selected_attn_weights = (selected_attn_weights > 0.0).to(
                selected_attn_weights.dtype
            )
            selected_attn_weights = selected_attn_weights * (
                1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)
            )

        na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights))

        src = src + (
            na if self_attn_dropout_mask is None else na * self_attn_dropout_mask
        )

        self_attn = self.self_attn1(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module1(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff2_skip_rate = 0.0
        else:
            ff2_skip_rate = float(self.ff2_skip_rate) if self.training else 0.0

        moe_output2, gate_logits = self._apply_ffn('2', self.moe_2, src, gate_logits)
        src = src + self.sequence_dropout(
            self.balancer_ff2(moe_output2), ff2_skip_rate
        )

        # bypass in the middle of the layer.
        src = self.bypass_mid(src_orig, src)

        self_attn = self.self_attn2(src, attn_weights)

        src = src + (
            self_attn
            if self_attn_dropout_mask is None
            else self_attn * self_attn_dropout_mask
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            conv_skip_rate = 0.0
        else:
            conv_skip_rate = float(self.conv_skip_rate) if self.training else 0.0
        src = src + self.sequence_dropout(
            self.conv_module2(
                src, chunk_size=chunk_size, src_key_padding_mask=src_key_padding_mask
            ),
            conv_skip_rate,
        )

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            ff3_skip_rate = 0.0
        else:
            ff3_skip_rate = float(self.ff3_skip_rate) if self.training else 0.0

        moe_output3, gate_logits = self._apply_ffn('3', self.moe_3, src, gate_logits)
        src = src + self.sequence_dropout(
            self.balancer_ff3(moe_output3), ff3_skip_rate
        )

        src = self.balancer1(src)
        src = self.norm(src)

        src = self.bypass(src_orig, src)

        src = self.balancer2(src)
        src = self.whiten(src)

        return src, gate_logits


class Zipformer2MoeEncoder(Zipformer2Encoder):
    r"""Zipformer2MoeEncoder is a stack of N MoE encoder layers.

    Like Zipformer2Encoder, but each layer returns ``(output, gate_logits)`` and the
    per-layer gate-logit tuples are concatenated and returned alongside the output.

    Args:
        encoder_layer: a layer whose forward returns ``(output, tuple_of_gate_logits)``,
            e.g. ZipformerMoeEncoderLayer2 (required). It is deep-copied num_layers times.
        num_layers: the number of sub-encoder-layers in the encoder (required).
        pos_dim: the dimension for the relative positional encoding
        dropout: dropout value forwarded to Zipformer2Encoder.
        warmup_begin, warmup_end: batch-index range over which each layer's bypass
            skip rate is annealed from initial_layerdrop_rate to final_layerdrop_rate.

    Example::

        encoder_layer = ZipformerMoeEncoderLayer2(
            embed_dim=384, pos_dim=192, num_heads=8, query_head_dim=24,
            pos_head_dim=4, value_head_dim=12, feedforward_dim=1536,
        )
        encoder = Zipformer2MoeEncoder(
            encoder_layer, num_layers=4, pos_dim=192, dropout=0.1,
            warmup_begin=0.0, warmup_end=4000.0,
        )
        out, gate_logits = encoder(torch.rand(10, 32, 384))
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        pos_dim: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
        initial_layerdrop_rate: float = 0.5,
        final_layerdrop_rate: float = 0.05,
    ) -> None:
        super().__init__(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            pos_dim=pos_dim,
            dropout=dropout,
            warmup_begin=warmup_begin,
            warmup_end=warmup_end,
            initial_layerdrop_rate=initial_layerdrop_rate,
            final_layerdrop_rate=final_layerdrop_rate,
        )
        # NOTE(review): the statements below rebuild the positional encoding,
        # the layer stack and the skip-rate schedule; if Zipformer2Encoder's
        # constructor already does the same, this is redundant -- confirm
        # against zipformer.py before removing.
        self.encoder_pos = CompactRelPositionalEncoding(
            pos_dim, dropout_rate=0.15, length_factor=1.0
        )

        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers

        assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end)

        # Stagger the layerdrop warmup: layer i anneals its bypass skip rate
        # over its own equal-length slice of [warmup_begin, warmup_end].
        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
        cur_begin = warmup_begin  # interpreted as a training batch index
        for i in range(num_layers):
            cur_end = cur_begin + delta
            self.layers[i].bypass.skip_rate = ScheduledFloat(
                (cur_begin, initial_layerdrop_rate),
                (cur_end, final_layerdrop_rate),
                default=0.0,
            )
            cur_begin = cur_end

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tuple[Tensor, ...]]:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                 True means masked position. May be None.
            src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
                 masked position.  May be None.

        Returns:
            (output, gate_logits): output has the same shape as src; gate_logits is the
            concatenation, in layer order, of every layer's gate-logit tuple.
        """
        pos_emb = self.encoder_pos(src)
        output = src
        gate_logits = ()

        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            output = output * feature_mask

        for mod in self.layers:
            output, layer_gate_logits = mod(
                output,
                pos_emb,
                chunk_size=chunk_size,
                attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
            )
            gate_logits += layer_gate_logits
            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                output = output * feature_mask

        return output, gate_logits


class DownsampledZipformer2MoeEncoder(DownsampledZipformer2Encoder):
    r"""
    DownsampledZipformer2Encoder is a zipformer encoder evaluated at a reduced frame rate,
    after convolutional downsampling, and then upsampled again at the output, and combined
    with the origin input, so that the output has the same shape as the input.

    This MoE variant additionally returns the gate logits produced by the wrapped
    encoder (which must return ``(output, gate_logits)``).
    """

    def __init__(
        self,
        encoder: nn.Module,
        dim: int,
        downsample: int,
        dropout: FloatLike,
        causal: bool,
    ):
        # All submodules (downsample, encoder, upsample, out_combiner) are
        # created by the parent class.
        super().__init__(
            encoder=encoder,
            dim=dim,
            downsample=downsample,
            dropout=dropout,
            causal=causal,
        )

    def forward(
        self,
        src: Tensor,
        chunk_size: int = -1,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tuple[Tensor, ...]]:
        r"""Downsample, go through encoder, upsample.

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            chunk_size: frames per chunk at the input rate; it is integer-divided by the
               downsampling factor before reaching the inner encoder (-1 stays -1).
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_mask: the attention mask, of shape (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                 interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
                 True means masked position. May be None.
            src_key_padding_mask:  the mask for padding; True means masked position. May be None.
                 It is passed to the inner encoder unchanged, so it must already be at the
                 downsampled frame rate (ZipformerMoe.forward slices it with ``[..., ::ds]``).

        Returns:
            (output, gate_logits): output has the same shape as src; gate_logits is the
            tuple returned by the wrapped encoder.
        """
        src_orig = src
        src = self.downsample(src)
        ds = self.downsample_factor
        if attn_mask is not None:
            # Subsample only the trailing (tgt, src) axes so both the 2-D
            # (seq_len, seq_len) and the batched 3-D layouts are handled;
            # `attn_mask[::ds, ::ds]` would subsample the batch axis of a 3-D mask.
            attn_mask = attn_mask[..., ::ds, ::ds]

        src, gate_logits = self.encoder(
            src,
            chunk_size=chunk_size // ds,
            feature_mask=feature_mask,
            attn_mask=attn_mask,
            src_key_padding_mask=src_key_padding_mask,
        )
        src = self.upsample(src)
        # remove any extra frames that are not a multiple of downsample_factor
        src = src[: src_orig.shape[0]]

        return self.out_combiner(src_orig, src), gate_logits
    
    

class MoEModel(nn.Module):
    """Mixture-of-experts feed-forward block with top-k routing and optional shared experts.

    Args:
        num_experts: number of routed experts; 0 disables routing (no gate is built
            and ``forward`` returns ``None`` gate logits).
        embed_dim: input/output feature dimension.
        top_k: number of routed experts each token is sent to.
        hidden_dim: hidden dim of the equivalent dense FFN; every expert uses
            ``hidden_dim // granularity``.
        dropout: dropout passed to every FeedforwardModule.
        granularity: divisor applied to ``hidden_dim`` for each expert.
        num_shared_experts: experts applied to every token with weight 1;
            must not exceed ``top_k``.
        normalized_score: if True, each token's top-k routing weights are
            renormalized to sum to 1 (equivalent to top-k before softmax).
    """

    def __init__(self, num_experts, embed_dim, top_k, hidden_dim, dropout, granularity=1, num_shared_experts=0, normalized_score=True):
        super(MoEModel, self).__init__()

        self.num_experts = num_experts  # number of routed experts
        self.top_k = top_k
        self.granularity = granularity
        assert num_shared_experts <= top_k
        self.num_shared_experts = num_shared_experts
        self.normalized_score = normalized_score

        if self.num_experts > 0:
            # Router: one logit per routed expert.
            self.gate = nn.Linear(embed_dim, self.num_experts, bias=False)

            self.experts = nn.ModuleList([
                FeedforwardModule(embed_dim, hidden_dim // granularity, dropout)
                for _ in range(self.num_experts)
            ])

        if self.num_shared_experts > 0:
            # Shared experts see every token.
            self.shared_experts = nn.ModuleList([
                FeedforwardModule(embed_dim, hidden_dim // granularity, dropout)
                for _ in range(self.num_shared_experts)
            ])

    def forward(self, src: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
        """Route each token through its top-k experts (plus shared experts).

        Args:
            src: input of shape (..., embed_dim), typically (seq_len, batch_size, embed_dim).

        Returns:
            (output, gate_logits): output has the same shape and dtype as ``src``.
            gate_logits are the router logits clamped to [-100, 100], of shape
            (N, num_experts) where N is the product of src's leading dims
            (flattened in row-major order), or None when num_experts == 0.
        """
        leading_shape = src.shape[:-1]
        flat = src.reshape(-1, src.shape[-1])  # (N, embed_dim)
        moe_output = torch.zeros_like(flat)

        gate_logits: Optional[Tensor] = None
        if self.num_experts > 0:
            # softmax-then-topk; clamp guards against overflow/underflow.
            gate_logits = torch.clamp(self.gate(flat), min=-100.0, max=100.0)
            routing_weights = torch.softmax(gate_logits, dim=-1, dtype=torch.float)  # (N, num_experts)
            routing_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)
            if self.normalized_score:
                # With renormalization this is equivalent to top-k first.
                routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)

            for i, expert in enumerate(self.experts):
                # Top-k indices are distinct per token, so each token appears at
                # most once in token_idx. Experts are still called on empty
                # selections so all their parameters stay in the autograd graph.
                token_idx, slot_idx = torch.where(top_k_indices == i)
                weighted = routing_weights[token_idx, slot_idx, None] * expert(flat[token_idx])
                # routing_weights are float32; cast so accumulation also works
                # when the activations are float16/bfloat16.
                moe_output.index_add_(0, token_idx, weighted.to(moe_output.dtype))

        if self.num_shared_experts > 0:
            for shared_expert in self.shared_experts:
                moe_output += shared_expert(flat)

        return moe_output.reshape(*leading_shape, -1), gate_logits

