from dataclasses import dataclass
from typing import List, Optional, Tuple

import torch
from torch import nn
from torch import Tensor

from models.attention import Attention, PreNorm, RoPEAttention
from models.config import MiMoEConfig
from models.expert import Expert
from models.registry import get_router, get_buffer


@dataclass
class MiMoELayerOutput:
    """Container for the values produced by ``MiMoELayer.forward``.

    Constructor arguments follow the declaration order of the fields below.
    """

    # Hidden states after the attention and expert-mixture residual updates.
    feature: Tensor
    # Updated router state returned by the layer's router (presumably consumed
    # by the next layer's router -- confirm against the model's layer loop).
    router_residual: Tensor
    # Full routing scores exactly as returned by the router; shape is router-defined.
    scores: Tensor
    # Scalar statistic reported by the dispatch buffer (its fourth return value).
    buffer_ratio: float
    # Mean of the dispatch mask, i.e. the fraction of (token, slot) entries dispatched.
    routing_rate: float


class MiMoELayer(nn.Module):
    """Layer made of attention followed by a routed mixture of experts.

    ``forward`` applies two residual updates in order:

    1. ``self.attention`` -- ``RoPEAttention`` when
       ``config.position_embedding == "rope"``, otherwise ``Attention``,
       wrapped in ``PreNorm``;
    2. a sparse expert mixture: ``self.router`` scores the tokens,
       ``self.buffer`` groups routed rows per expert, and ``_run_experts``
       runs the experts and scatters their weighted outputs back per token.
    """

    def __init__(
        self,
        config: MiMoEConfig,
    ):
        super().__init__()
        # ``moe_top_k`` is the number of expert slots per token that
        # ``_run_experts`` assumes (the K of ``dispatch_mask``); the pool holds
        # ``granularity * expansion_ratio`` experts.
        self.moe_top_k = config.granularity
        self.num_experts = config.granularity * config.expansion_ratio

        attn_cls = RoPEAttention if config.position_embedding == "rope" else Attention
        self.attention = PreNorm(config.hidden_dim, attn_cls(config))
        self.router = get_router(config)
        self.buffer = get_buffer(config)
        self.experts = nn.ModuleList(
            [PreNorm(config.hidden_dim, Expert(config)) for _ in range(self.num_experts)]
        )

    def forward(
        self,
        x: Tensor,
        router_residual: Tensor,
    ) -> MiMoELayerOutput:
        """Apply attention, then the routed expert mixture, to ``x``.

        Args:
            x: Input features of shape ``(B, T, D)``.
            router_residual: Router state passed to ``self.router`` together
                with the post-attention features; the router returns an
                updated value that is forwarded in the output.

        Returns:
            MiMoELayerOutput holding the updated features, the router's new
            residual and full scores, the buffer's reported ratio, and the
            mean of the dispatch mask.
        """
        B, T, D = x.shape

        x = x + self.attention(x)
        scores_k, indices_k, router_residual, scores = self.router(x, router_residual)
        expert_inputs, expert_scores, dispatch_mask, buffer_ratio = self.buffer(
            x, scores_k, indices_k
        )
        # Accumulate in x's dtype so the residual add below neither fails nor
        # silently promotes a half/bfloat16 residual stream to float32.
        expert_outputs = self._run_experts(
            expert_inputs, expert_scores, dispatch_mask, indices_k, (B, T, D),
            dtype=x.dtype,
        )
        x = x + expert_outputs
        # .item() synchronises with the device; it keeps the field a plain float.
        routing_rate = dispatch_mask.float().mean().item()
        return MiMoELayerOutput(
            feature=x,
            router_residual=router_residual,
            scores=scores,
            buffer_ratio=buffer_ratio,
            routing_rate=routing_rate,
        )

    def _run_experts(
        self,
        expert_inputs: List[Tensor],
        expert_scores: List[Tensor],
        dispatch_mask: Tensor,  # (B, T, K)
        indices_k: Tensor,  # (B, T, K)
        shape: Tuple[int, int, int],  # (B, T, D)
        dtype: Optional[torch.dtype] = None,
    ) -> Tensor:
        """Run each expert on its dispatched rows and sum the results per token.

        Args:
            expert_inputs: ``expert_inputs[i]`` holds the rows routed to expert
                ``i``. NOTE(review): assumed to be ordered like the set entries
                of ``dispatch_mask`` (flattened ``(B, T, K)`` order) whose
                ``indices_k`` equals ``i`` -- confirm against the buffer.
            expert_scores: ``expert_scores[i]`` holds one weight per row of
                ``expert_inputs[i]``.
            dispatch_mask: Per-(token, slot) flag; non-zero means dispatched.
            indices_k: Expert id for each (token, slot).
            shape: ``(B, T, D)`` of the layer input.
            dtype: dtype of the returned tensor. ``None`` keeps the previous
                behaviour (torch's default dtype).

        Returns:
            Tensor of shape ``(B, T, D)``: for each token, the sum over its
            dispatched slots of ``expert(row) * score``; undispatched slots
            contribute zero.
        """
        B, T, D = shape
        k = self.moe_top_k

        # One boolean view of the mask so nonzero() and boolean indexing agree
        # even if the buffer hands back a non-bool or non-contiguous mask.
        flat_mask = dispatch_mask.reshape(-1).bool()  # (B * T * K,)
        # Flat (token, slot) positions of dispatched entries: (num_dispatched,).
        slot_positions = flat_mask.nonzero(as_tuple=False).squeeze(-1)
        # Expert id of each dispatched entry, aligned with slot_positions.
        expert_ids = indices_k.reshape(-1)[flat_mask]

        flat_output = torch.zeros(B * T * k, D, device=dispatch_mask.device, dtype=dtype)
        for i, expert in enumerate(self.experts):
            x_i = expert_inputs[i]
            routed_to_i = expert_ids == i
            # numel() is checked first: it is free, whereas .any() syncs the device.
            if x_i.numel() == 0 or not routed_to_i.any():
                continue
            out_i = expert(x_i) * expert_scores[i].unsqueeze(-1)
            # Advanced-index assignment requires matching dtypes; expert outputs
            # can differ (autocast, float32 routing scores), so cast explicitly.
            flat_output[slot_positions[routed_to_i]] = out_i.to(flat_output.dtype)
        return flat_output.view(B, T, k, D).sum(dim=2)