from typing import Generator, Optional, Tuple
from mpmath.libmp.libelefun import k
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.model.config import ModelConfig, Transformer
from src.model.base import make_attention_mask, Attention, MLP


class DEMixLayer(nn.Module):
    """
    Gateless mixture-of-experts feed-forward layer.

    Holds one "core" MLP at index 0 (hidden width ``mlp_dim``) followed by
    ``num_experts - 1`` auxiliary MLPs (hidden width ``aux_dim``). There is no
    learned router: the caller picks the experts, either with a single index
    or with a multi-hot mask. The layer never switches the core expert on by
    itself; an expert contributes only when the caller selects it.
    """
    def __init__(
        self,
        embed_dim: int,
        num_experts: int,
        mlp_dim: int,
        aux_dim: int,
    ) -> None:
        """
        embed_dim:   input/output width of every expert
        num_experts: total number of experts, core included
        mlp_dim:     hidden width of the core expert
        aux_dim:     hidden width shared by all auxiliary experts
        """
        super().__init__()

        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [MLP(embed_dim, mlp_dim)] +        # core (idx 0)
            [MLP(embed_dim, aux_dim)           # aux  (idx 1..E-1)
             for _ in range(num_experts - 1)]
        )

    def forward(self, x: torch.Tensor, select_mask: torch.Tensor, exp_idx: int) -> torch.Tensor:
        """
        Apply the selected expert(s) to ``x``.

        Two paths:
          * ``exp_idx != -1`` (or the layer has a single expert): run only that
            expert. ``select_mask`` is not read.
          * ``exp_idx == -1``: evaluate every expert (aux experts batched in
            one einsum), weight each output by ``select_mask`` and sum.

        x:           (B, T, E) for batch, sequence length, and embedding dimension
        select_mask: (K,) multi-hot selection of experts, K = number of experts
        exp_idx:     expert index to run alone, or -1 for the masked dense path
        returns:     (B, T, E)
        """
        K = len(self.experts)
        if K == 1:
            # A single-expert layer has nothing to choose between.
            exp_idx = 0

        if exp_idx != -1:
            return self.experts[exp_idx](x)

        # Core expert (idx 0) may have a different hidden width, so it runs on its own.
        core_output = self.experts[0](x)

        # Stack the aux experts' weights so they can be evaluated in one shot.
        # NOTE(review): this re-implements the aux MLP as
        # c_fc -> tanh-GELU -> c_proj with biases; confirm it matches MLP.forward.
        aux_experts = self.experts[1:]
        aux_fc_w = torch.stack([e.c_fc.weight for e in aux_experts], dim=0)
        aux_fc_b = torch.stack([e.c_fc.bias for e in aux_experts], dim=0)
        aux_proj_w = torch.stack([e.c_proj.weight for e in aux_experts], dim=0)
        aux_proj_b = torch.stack([e.c_proj.bias for e in aux_experts], dim=0)

        # First layer: (B,T,E) with (K-1,H,E) -> (B,T,K-1,H)
        aux_hidden = torch.einsum('bte,khe->btkh', x, aux_fc_w) + aux_fc_b.view(1, 1, K-1, -1)
        aux_hidden = F.gelu(aux_hidden, approximate="tanh")

        # Second layer: (B,T,K-1,H) with (K-1,E,H) -> (B,T,K-1,E)
        aux_output = torch.einsum('btkh,keh->btke', aux_hidden, aux_proj_w) + aux_proj_b.view(1, 1, K-1, -1)

        # Put the core output in slot 0 of the expert axis: (B, T, K, E)
        acts = torch.cat([core_output.unsqueeze(2), aux_output], dim=2)

        # Cast the mask to the activation dtype first: otherwise a wider mask
        # dtype (e.g. float32 mask with bf16 activations) promotes the output
        # and this path disagrees with the single-expert path above.
        weights = select_mask.to(dtype=acts.dtype).reshape(1, 1, K, 1)

        # Zero out unselected experts and sum over the expert axis.
        return (acts * weights).sum(dim=2)

    def get_params(self, expert_idx: int) -> Generator[torch.Tensor, None, None]:
        """Yield the parameters of the expert at ``expert_idx`` (0 is the core expert)."""
        assert 0 <= expert_idx < self.num_experts, f"Expert index {expert_idx} out of range"
        yield from self.experts[expert_idx].parameters()


class Block(nn.Module):
    """
    Pre-norm transformer block: attention sub-layer, then a DEMixLayer
    feed-forward sub-layer, each wrapped in a residual connection.
    """
    def __init__(
        self, 
        config: ModelConfig,
        num_experts: int, 
        mlp_dim: int,
        aux_dim: int,
    ) -> None:
        """
        config:      supplies num_heads, embed_dim, num_key_value and attn_bias
        num_experts: number of experts in this block's DEMixLayer (core included)
        mlp_dim:     hidden width of the core expert
        aux_dim:     hidden width of the auxiliary experts
        """
        super().__init__()
        self.attn = Attention(
            num_heads=config.num_heads,
            embed_dim=config.embed_dim,
            num_key_value=config.num_key_value,
            bias=config.attn_bias,
        )

        self.moe = DEMixLayer(
            embed_dim=config.embed_dim,
            num_experts=num_experts,
            mlp_dim=mlp_dim,
            aux_dim=aux_dim,
        )

        # One norm in front of each sub-layer.
        self.norm_1 = nn.RMSNorm(config.embed_dim)
        self.norm_2 = nn.RMSNorm(config.embed_dim)

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: torch.Tensor,
        select_mask: torch.Tensor, # (K,) multi-hot mask
        exp_idx: int,
    ) -> torch.Tensor:
        """
        Run attention and the expert layer with residual connections.

        x:           hidden states; the output has the same shape
        attn_mask:   forwarded unchanged to the attention module
        select_mask: (K,) multi-hot expert mask, forwarded to the expert layer
        exp_idx:     expert index forwarded to the expert layer (-1 selects
                     its masked dense path)
        returns:     the updated hidden states (a single tensor)
        """
        x = x + self.attn(self.norm_1(x), attn_mask=attn_mask)
        x = x + self.moe(self.norm_2(x), select_mask=select_mask, exp_idx=exp_idx)
        return x


class DemixTransformer(Transformer):
    """
    Transformer whose feed-forward sub-layers are DEMix expert layers.

    Layers listed in ``config.target_layers`` get one expert per label
    ("core" plus ``config.aux_labels``); every other layer gets a single
    (core) expert.
    """
    def __init__(
        self, config: ModelConfig,
        mlp_dim: int,
        aux_dim: int,
    ) -> None:
        """
        config:  model hyper-parameters (vocab_size, embed_dim, num_layers,
                 target_layers, aux_labels, ...)
        mlp_dim: hidden width of the core expert in every block
        aux_dim: hidden width of the auxiliary experts
        """
        super().__init__(config)

        self.model_type = "routed"
        # Expert index i in a multi-expert block corresponds to self.labels[i].
        self.labels = ["core"] + config.aux_labels

        self.embed = nn.Embedding(config.vocab_size, config.embed_dim)
        self.norm = nn.RMSNorm(config.embed_dim)
        self.unembed = nn.Linear(config.embed_dim, config.vocab_size, bias=True)

        # Build the stack in a local list so the attribute is only ever a ModuleList.
        blocks = []
        for idx in range(config.num_layers):
            num_experts = len(self.labels) if idx in config.target_layers else 1
            blocks.append(Block(
                config=config,
                num_experts=num_experts,
                mlp_dim=mlp_dim,
                aux_dim=aux_dim,
            ))
        self.blocks = nn.ModuleList(blocks)

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise Linear/Embedding weights ~ N(0, 0.02), biases to 0, RMSNorm gains to 1."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # nn.Embedding has no bias attribute; nn.Linear may have bias=None.
            if getattr(module, "bias", None) is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.RMSNorm):
            nn.init.ones_(module.weight)

    def get_params(self, label: str) -> Generator[torch.Tensor, None, None]:
        """
        Yield the parameters belonging to one parameter group.

        label == "SHARED": embedding, final norm, unembedding, and every
            block's attention and norm parameters. No expert parameters.
        label in self.labels: that expert's parameters from every block that
            has an expert at the label's index. Shared parameters are NOT
            included, not even for "core".
        """
        all_labels = self.labels + ["SHARED"]
        assert label in all_labels, f"Label {label} not found in {all_labels}"

        is_shared = label == "SHARED"

        # Non-block shared parameters are yielded only for the "SHARED" group.
        if is_shared:
            yield from self.embed.parameters()
            yield from self.norm.parameters()
            yield from self.unembed.parameters()

        # -1 means "no expert group requested" (the SHARED case).
        e_idx = self.labels.index(label) if label in self.labels else -1

        for block in self.blocks:

            if is_shared:
                yield from block.attn.parameters()
                yield from block.norm_1.parameters()
                yield from block.norm_2.parameters()

            # Single-expert blocks only have index 0, so aux labels skip them.
            if -1 < e_idx < block.moe.num_experts:
                yield from block.moe.experts[e_idx].parameters()

    def forward(
        self,
        tokens: torch.Tensor, #(B, T)
        targets: Optional[torch.Tensor],
        select_mask: torch.Tensor,
        optimize: bool = True,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run the model and optionally compute the cross-entropy loss.

        tokens:      (B, T) token ids
        targets:     target ids for the loss, or None to skip the loss
        select_mask: one-hot mask over self.labels; exactly one entry must be set
        optimize:    if True, each block runs only the selected expert; if
                     False, blocks take the masked dense path (exp_idx == -1)
        returns:     (logits of shape (B, T, V), loss or None)
        """
        # Reduce on the tensor instead of iterating its elements in Python.
        assert select_mask.sum() == 1, "DEMix: Only one expert should be active"

        exp_idx = -1
        if optimize:
            exp_idx = select_mask.nonzero()[0].item()

        attn_mask = make_attention_mask(tokens, self.config.eos_token_id)
        h = self.embed(tokens)
        for block in self.blocks:
            h = block(h, attn_mask=attn_mask, select_mask=select_mask, exp_idx=exp_idx)
        logits = self.unembed(self.norm(h))  # (B, T, V)

        loss = None
        if targets is not None:
            # NOTE(review): ignore_index is -1000, not PyTorch's default -100;
            # confirm the data pipeline pads targets with -1000.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1000,
                reduction="mean",
            )

        return logits, loss