import torch
from torch import nn
from torch import Tensor
from typing import List, Tuple
from models.config import MiMoEConfig

from models.registry import register_buffer



@register_buffer("topk")
class TopKBuffer(nn.Module):
    """Plain top-k dispatch: every (token, slot) routing choice is kept.

    Groups the flattened token representations by the expert that each
    top-k slot was routed to, without dropping anything.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__()
        # Each token is routed to ``granularity`` experts out of
        # ``granularity * expansion_ratio`` experts in total.
        self.moe_top_k = config.granularity
        self.num_experts = config.granularity * config.expansion_ratio

    def forward(
        self,
        x: Tensor,          # (B, T, D)
        scores_k: Tensor,   # (B, T, k)
        indices_k: Tensor   # (B, T, k)
    ) -> Tuple[List[Tensor], List[Tensor], Tensor, float]:
        """Group tokens by their routed expert.

        Args:
            x: Token representations, (B, T, D).
            scores_k: Routing scores of the selected experts, (B, T, k).
            indices_k: Selected expert ids, (B, T, k). Assumed to lie in
                ``[0, num_experts)`` (presumably produced by a router top-k;
                confirm against the caller).

        Returns:
            expert_inputs: ``num_experts`` tensors; entry ``e`` is (M_e, D)
                with the tokens routed to expert ``e``, ordered by ascending
                flattened (token, slot) position. Empty entries keep ``x``'s
                dtype and device.
            expert_scores: Matching (M_e,) routing scores, same order.
            dispatch_mask: All-True bool tensor shaped like ``indices_k``
                (top-k routing never drops a choice).
            1.0: Fraction of routing choices that were dispatched.
        """
        B, T, D = x.shape
        k = self.moe_top_k
        N = B * T

        # reshape (not view) so non-contiguous inputs are accepted.
        flat_x = x.reshape(N, D)                   # (N, D)
        flat_indices = indices_k.reshape(-1)       # (N*k,)
        flat_scores = scores_k.reshape(-1)         # (N*k,)
        # Owning token of every flattened (token, slot) choice: 0,0,..,1,1,..
        token_idx = torch.arange(N, device=x.device).repeat_interleave(k)  # (N*k,)

        # A stable sort keeps the choices of one expert in ascending flattened
        # position -- the same order a boolean mask over indices_k yields --
        # and needs a single host sync (tolist) instead of one per expert.
        sorted_expert, order = torch.sort(flat_indices, stable=True)
        counts = torch.bincount(sorted_expert, minlength=self.num_experts).tolist()

        # Splitting slices of the gathered tensors keeps x's / scores' dtype,
        # including for experts that received nothing.
        expert_inputs: List[Tensor] = list(torch.split(flat_x[token_idx[order]], counts))
        expert_scores: List[Tensor] = list(torch.split(flat_scores[order], counts))

        # TopK routing: every choice is dispatched, so the mask is always True.
        dispatch_mask = torch.ones_like(indices_k, dtype=torch.bool)
        return expert_inputs, expert_scores, dispatch_mask, 1.0



@register_buffer("static_buffer")
class StaticBuffer(TopKBuffer):
    """Top-k dispatch with a fixed per-expert capacity.

    The capacity is computed once from the configured batch size and sequence
    length. When more routing choices target an expert than fit, only the
    highest-scoring ones are kept and the rest are dropped.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)
        self.buffer_ratio = config.buffer_ratio
        self.batch_size = config.batch_size
        self.max_seq_len = config.max_seq_len
        self.buffer_capacity = self.calculate_buffer_capacity()

    def calculate_buffer_capacity(self):
        """Return the number of slots each expert gets.

        Equals ``buffer_ratio`` times the average number of (token, slot)
        choices one expert would receive for a full configured batch,
        i.e. ``top_k * batch_size * max_seq_len / num_experts``, truncated
        to an int. It does not depend on the actual batch shape seen in
        ``forward``.
        """
        total_pairs = self.moe_top_k * self.batch_size * self.max_seq_len
        return int(total_pairs * self.buffer_ratio / self.num_experts)

    def forward(
        self,
        x: Tensor,          # (B, T, D)
        scores_k: Tensor,   # (B, T, k)
        indices_k: Tensor   # (B, T, k)
    ) -> Tuple[List[Tensor], List[Tensor], Tensor, float]:
        """Group tokens by expert, keeping at most ``buffer_capacity`` per expert.

        Args:
            x: Token representations, (B, T, D).
            scores_k: Routing scores of the selected experts, (B, T, k).
            indices_k: Selected expert ids, (B, T, k).

        Returns:
            expert_inputs: ``num_experts`` tensors; entry ``e`` is (M_e, D)
                with ``M_e <= buffer_capacity``, holding the highest-scoring
                choices routed to expert ``e``. Rows are ordered by ascending
                flattened (token, slot) position, the same order as
                ``TopKBuffer`` and as the True entries of ``dispatch_mask``.
            expert_scores: Matching (M_e,) routing scores, same order.
            dispatch_mask: Bool (B, T, k); True where the choice was kept.
            buffer_ratio: The configured capacity ratio.
        """
        B, T, D = x.shape
        k = self.moe_top_k
        N = B * T
        device = x.device

        # reshape (not view) so non-contiguous inputs are accepted.
        flat_x = x.reshape(N, D)                   # (N, D)
        flat_indices = indices_k.reshape(-1)       # (N*k,)
        flat_scores = scores_k.reshape(-1)         # (N*k,)
        token_idx = torch.arange(N, device=device).unsqueeze(1).expand(N, k).reshape(-1)

        dispatch_mask = torch.zeros_like(flat_scores, dtype=torch.bool)  # (N*k,)
        expert_inputs: List[Tensor] = []
        expert_scores: List[Tensor] = []
        for e in range(self.num_experts):
            # Flattened positions routed to expert e, ascending.
            pos_e = (flat_indices == e).nonzero(as_tuple=False).squeeze(1)  # (M,)
            if pos_e.numel() == 0:
                # Placeholders share the inputs' dtype/device so downstream
                # expert layers (e.g. under bf16) accept them.
                expert_inputs.append(x.new_empty((0, D)))
                expert_scores.append(scores_k.new_empty((0,)))
                continue

            cap = min(self.buffer_capacity, pos_e.numel())
            top_idx = torch.topk(flat_scores[pos_e], k=cap, largest=True).indices
            # topk returns score order; sorting the local indices restores the
            # ascending-position order (pos_e is ascending).
            chosen_pos = pos_e[torch.sort(top_idx).values]

            dispatch_mask[chosen_pos] = True
            expert_inputs.append(flat_x[token_idx[chosen_pos]])
            expert_scores.append(flat_scores[chosen_pos])

        return expert_inputs, expert_scores, dispatch_mask.view(B, T, k), self.buffer_ratio



@register_buffer("threshold_buffer")
class ThresholdBuffer(TopKBuffer):
    """Global capacity: keep the highest-scoring fraction of all routing choices.

    Of the B*T*k candidate (token, slot) choices in the batch, the
    ``int(B*T*k * topk_threshold)`` with the largest scores are dispatched.
    The rest are dropped, regardless of which expert they target.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)
        self.topk_threshold = config.topk_threshold

    def forward(
        self,
        x: Tensor,          # (B, T, D)
        scores_k: Tensor,   # (B, T, k)
        indices_k: Tensor   # (B, T, k)
    ) -> Tuple[List[Tensor], List[Tensor], Tensor, float]:
        """Keep the global top fraction of choices and group them by expert.

        Args:
            x: Token representations, (B, T, D).
            scores_k: Routing scores of the selected experts, (B, T, k).
            indices_k: Selected expert ids, (B, T, k).

        Returns:
            expert_inputs: ``num_experts`` tensors; entry ``e`` is (M_e, D),
                holding the kept choices routed to expert ``e``. Rows are
                ordered by ascending flattened (token, slot) position.
            expert_scores: Matching (M_e,) routing scores, same order.
            dispatch_mask: Bool (B, T, k); True where the choice was kept.
                Exactly ``cap`` entries are True.
            topk_threshold: The configured keep ratio.
        """
        B, T, D = x.shape
        k = self.moe_top_k
        N = B * T
        device = x.device

        # reshape (not view) so non-contiguous inputs are accepted.
        flat_x = x.reshape(N, D)                   # (N, D)
        flat_indices = indices_k.reshape(-1)       # (N*k,)
        flat_scores = scores_k.reshape(-1)         # (N*k,)
        token_idx = torch.arange(N, device=device).unsqueeze(1).expand(N, k).reshape(-1)  # (N*k,)

        # --- selection ---
        total_candidates = flat_scores.numel()
        cap = max(0, min(int(total_candidates * self.topk_threshold), total_candidates))
        chosen_mask = torch.zeros_like(flat_scores, dtype=torch.bool)  # (N*k,)
        if cap > 0:
            # Select exactly `cap` candidates by index. Comparing against the
            # cap-th largest value would also admit every tie at the cut-off
            # (e.g. identical padding tokens), overshooting the capacity.
            top_idx = torch.topk(flat_scores, k=cap, largest=True).indices
            chosen_mask[top_idx] = True

        # Boolean masking preserves ascending flattened position.
        chosen_indices = flat_indices[chosen_mask]   # (M,)
        chosen_tokens = token_idx[chosen_mask]       # (M,)
        chosen_scores = flat_scores[chosen_mask]     # (M,)

        # --- grouping ---
        # A stable sort keeps each expert's rows in ascending position, making
        # the order deterministic and consistent with dispatch_mask.
        sorted_expert, order = torch.sort(chosen_indices, stable=True)
        split_sizes = torch.bincount(sorted_expert, minlength=self.num_experts).tolist()
        expert_inputs = torch.split(flat_x[chosen_tokens[order]], split_sizes)
        expert_scores = torch.split(chosen_scores[order], split_sizes)

        # --- dispatch_mask ---
        dispatch_mask = chosen_mask.view(B, T, k)

        return list(expert_inputs), list(expert_scores), dispatch_mask, self.topk_threshold



@register_buffer("dynamic_buffer")
class DynamicBuffer(StaticBuffer):
    """Placeholder buffer registered as "dynamic_buffer".

    Construction reuses ``StaticBuffer``'s setup unchanged. Calling
    ``forward`` always raises ``NotImplementedError`` for now.
    """

    def __init__(self, config: MiMoEConfig):
        # No state of its own yet; everything comes from StaticBuffer.
        super().__init__(config)

    def forward(
        self,
        x: Tensor,          # (B, T, D)
        scores_k: Tensor,   # (B, T, k)
        indices_k: Tensor   # (B, T, k)
    ) -> Tuple[List[Tensor], List[Tensor], Tensor, float]:
        """Not available yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Dynamic buffer routing is not implemented yet.")