from typing import Generator, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.config import ModelConfig, Transformer
from src.model.base import make_attention_mask, Attention, MLP


class MoE(nn.Module):
    """
    Mixture-of-experts feed-forward layer without a learned gate.

    Expert 0 is the "core" expert and contributes to every output. Experts
    1..E-1 are auxiliary; the caller switches them on, either one at a time
    through an index or several at once through a multi-hot mask.
    """
    def __init__(
        self,
        embed_dim: int,
        num_experts: int,
        mlp_dim: int,
        aux_dim: int,
    ) -> None:

        super().__init__()

        self.num_experts = num_experts

        # Slot 0 is the core expert (built with mlp_dim); every following
        # slot is an auxiliary expert (built with aux_dim).
        members = [MLP(embed_dim, mlp_dim)]
        members.extend(MLP(embed_dim, aux_dim) for _ in range(num_experts - 1))
        self.experts = nn.ModuleList(members)

    def forward(
        self, x: torch.Tensor,
        select_mask: torch.Tensor,
        exp_idx: int,
    ) -> torch.Tensor:
        """
        Apply the core expert and add the requested auxiliary expert(s).

        x:           (B, T, E) for batch, sequence length, and embedding dimension
        select_mask: (K,) multi-hot selection of experts. Only entries 1..K-1
                     are read, and only on the batched path (negative exp_idx).
        exp_idx:     0  -> core expert only
                     >0 -> core expert plus the single expert at that index
                     <0 -> core expert plus every aux expert, each weighted
                           by its entry in select_mask
                     Treated as 0 when the layer holds only the core expert.
        returns:     (B, T, E)
        """

        n_experts = len(self.experts)
        n_aux = n_experts - 1

        # A layer that owns nothing but the core expert ignores the request.
        chosen = 0 if n_experts == 1 else exp_idx

        core_out = self.experts[0](x)

        if chosen == 0:
            return core_out

        if chosen > 0:
            return core_out + self.experts[chosen](x)

        # Batched path: evaluate all aux experts at once on stacked weights.
        aux_experts = self.experts[1:]
        fc_w = torch.stack([m.c_fc.weight for m in aux_experts], dim=0)
        fc_b = torch.stack([m.c_fc.bias for m in aux_experts], dim=0)
        proj_w = torch.stack([m.c_proj.weight for m in aux_experts], dim=0)
        proj_b = torch.stack([m.c_proj.bias for m in aux_experts], dim=0)

        # First layer: (B,T,E) with (K-1,H,E) -> (B,T,K-1,H)
        hidden = torch.einsum('bte,khe->btkh', x, fc_w) + fc_b.view(1, 1, n_aux, -1)
        hidden = F.gelu(hidden, approximate="tanh")

        # Second layer: (B,T,K-1,H) with (K-1,E,H) -> (B,T,K-1,E)
        per_expert = torch.einsum('btkh,keh->btke', hidden, proj_w) + proj_b.view(1, 1, n_aux, -1)

        # Weight each aux expert by its mask entry, then collapse the expert axis.
        gate = select_mask[1:].view(1, 1, n_aux, 1)
        mixed = (per_expert * gate).sum(dim=2)  # (B, T, E)

        return core_out + mixed  # (B, T, E)

    def get_params(self, expert_idx: int) -> Generator[torch.Tensor, None, None]:
        """Yield the parameters of the expert stored at expert_idx."""
        in_range = 0 <= expert_idx < self.num_experts
        assert in_range, f"Expert index {expert_idx} out of range"
        expert = self.experts[expert_idx]
        yield from expert.parameters()


class Block(nn.Module):
    """
    Pre-norm transformer block: attention followed by a gateless MoE layer,
    each wrapped in a residual connection.
    """
    def __init__(
        self,
        config: ModelConfig,
        num_experts: int,
        mlp_dim: int,
        aux_dim: int,
    ) -> None:
        """
        config:      supplies num_heads, embed_dim, num_key_value and attn_bias
                     for the attention layer, and embed_dim for the norms/MoE.
        num_experts: number of experts in this block's MoE (core included).
        mlp_dim:     forwarded to MoE as mlp_dim (core expert).
        aux_dim:     forwarded to MoE as aux_dim (auxiliary experts).
        """

        super().__init__()

        self.attn = Attention(
            num_heads=config.num_heads,
            embed_dim=config.embed_dim,
            num_key_value=config.num_key_value,
            bias=config.attn_bias,
        )
    
        self.moe = MoE(
            embed_dim=config.embed_dim,
            mlp_dim=mlp_dim,
            aux_dim=aux_dim,
            num_experts=num_experts,
        )

        # One norm in front of each sub-layer.
        self.norm_1 = nn.RMSNorm(config.embed_dim)
        self.norm_2 = nn.RMSNorm(config.embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        select_mask: torch.Tensor, # (K,) multi-hot mask
        exp_idx: int,
    ) -> torch.Tensor:
        """
        Run attention and the MoE layer, each on a normalized copy of the
        residual stream, and add each result back onto the stream.

        x:           residual stream; last dimension is config.embed_dim.
        attn_mask:   passed unchanged to the attention layer.
        select_mask: (K,) multi-hot expert selection, passed to the MoE layer.
        exp_idx:     expert index shortcut, passed to the MoE layer.
        returns:     a single tensor, the updated residual stream.
        """

        x = x + self.attn(self.norm_1(x), attn_mask=attn_mask)
        x = x + self.moe(self.norm_2(x), select_mask=select_mask, exp_idx=exp_idx)
        return x


class MoETransformer(Transformer):
    """
    Transformer language model whose feed-forward layers are gateless MoE
    layers. Experts are addressed by label: "core" plus config.aux_labels.
    """
    def __init__(
        self,
        config: ModelConfig,
        mlp_dim: int,
        aux_dim: int,
    ) -> None:
        """
        config:  model configuration (vocab_size, embed_dim, num_layers,
                 target_layers, aux_labels, ...).
        mlp_dim: forwarded to every block for its core expert.
        aux_dim: forwarded to every block for its auxiliary experts.
        """

        super().__init__(config)

        self.model_type = "routed"
        self.labels = ["core"] + config.aux_labels

        self.embed = nn.Embedding(config.vocab_size, config.embed_dim)
        self.norm = nn.RMSNorm(config.embed_dim)
        self.unembed = nn.Linear(config.embed_dim, config.vocab_size, bias=True)

        # Only layers listed in config.target_layers receive auxiliary
        # experts; every other layer holds the core expert alone.
        self.blocks = nn.ModuleList([
            Block(
                config=config,
                num_experts=len(self.labels) if idx in config.target_layers else 1,
                mlp_dim=mlp_dim,
                aux_dim=aux_dim,
            )
            for idx in range(config.num_layers)
        ])

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """
        Initialize one submodule: N(0, 0.02) weights and zero biases for
        Linear/Embedding layers, unit weights for RMSNorm layers.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # nn.Embedding has no bias attribute; nn.Linear may have bias=None.
            if getattr(module, "bias", None) is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.RMSNorm):
            nn.init.ones_(module.weight)

    def get_params(self, label: str) -> Generator[torch.Tensor, None, None]:
        """
        Yield the parameters that belong to one label.

        "core" covers the embedding, final norm, unembedding, and in every
        block the core expert, attention and both norms. Any other label
        covers the matching auxiliary expert in each block that has one.

        This is a generator, so an unknown label raises AssertionError on
        first iteration rather than at call time.
        """

        assert label in self.labels, f"Label {label} not found in {self.labels}"

        if label == "core":
            yield from self.embed.parameters()
            yield from self.norm.parameters()
            yield from self.unembed.parameters()

        e_idx = self.labels.index(label)

        for block in self.blocks:

            # Blocks outside the target layers hold only expert 0.
            if e_idx < block.moe.num_experts:
                yield from block.moe.experts[e_idx].parameters()

            if label == "core":
                yield from block.attn.parameters()
                yield from block.norm_1.parameters()
                yield from block.norm_2.parameters()


    def forward(
        self,
        tokens: torch.Tensor, #(B, T)
        targets: Optional[torch.Tensor], #(B, T)
        select_mask: torch.Tensor, #(K,) multi-hot mask
        optimize: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Compute logits, and the mean cross-entropy loss when targets are given.

        tokens:      (B, T) token ids.
        targets:     (B, T) target ids, or None to skip the loss.
        select_mask: (K,) multi-hot expert selection. The core slot (index 0)
                     is always treated as on; the caller's tensor is not
                     modified.
        optimize:    when True and at most two experts are selected, each MoE
                     layer is given the index of the rightmost selected expert
                     instead of running the batched multi-expert path.
        returns:     (logits, loss) with logits of shape (B, T, V) and loss
                     None when targets is None.
        """

        def run_stack(
            tok: torch.Tensor, 
            select_mask: torch.Tensor,
            exp_idx: int,
        ) -> torch.Tensor:

            attn_mask = make_attention_mask(tok, self.config.eos_token_id)
            h = self.embed(tok)
            for block in self.blocks:
                h = block(h, attn_mask=attn_mask, select_mask=select_mask, exp_idx=exp_idx)
            h = self.norm(h)
            return self.unembed(h)

        # Work on a private copy: forcing the core slot on must not write
        # into the tensor the caller handed in.
        select_mask = select_mask.clone()
        select_mask[0] = True #core is always on

        # -1 tells each MoE layer to use the batched, mask-driven path.
        exp_idx = -1
        if optimize and select_mask.sum() <= 2:
            exp_idx = select_mask.nonzero()[-1].item() #rightmost expert

        logits = run_stack(tokens, select_mask=select_mask, exp_idx=exp_idx)  # (B, T, V)

        loss = None
        if targets is not None:
            # NOTE(review): ignore_index is -1000, not PyTorch's default of
            # -100 — confirm the data pipeline marks ignored targets with -1000.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1000,
                reduction="mean",
            )

        return logits, loss