import torch
from torch import nn
from torch import Tensor
from statistics import mean
from transformers.modeling_utils import PreTrainedModel

from models.config import MiMoEConfig
from models.output import MiMoEOutput
from models.layer import MiMoELayer
from models.buffer import TopKBuffer, StaticBuffer, ThresholdBuffer
from models.registry import get_position_embedding


# Maps each modality name to its integer index (assigned in listing order).
modality_map = {
    name: index
    for index, name in enumerate(("vision", "audio", "text"))
}


class MiMoE(PreTrainedModel):
    """Stack of ``MiMoELayer`` blocks behind a linear input projection.

    The forward pass projects the input features, prepends a learnable CLS
    token, applies the configured position embedding and then runs every
    layer in sequence, threading a router residual from one layer to the
    next.  ``set_training_mode`` / ``set_inference_mode`` control which
    parameters are trainable and which buffer object each layer uses.
    """

    config_class = MiMoEConfig

    # Accepted values for the mode strings taken by the ``set_*_mode`` methods.
    _ROUTER_FINETUNE_MODES = ("frozen", "trainable", "cls_token")
    _DYNAMIC_BUFFER_MODES = ("static_buffer", "threshold_buffer")
    _BUFFER_MODES = ("topk",) + _DYNAMIC_BUFFER_MODES

    def __init__(self, config: MiMoEConfig):
        """Build the CLS token, input projection, position embedding and layers.

        Args:
            config: Model configuration.  The fields read here are
                ``granularity``, ``expansion_ratio``, ``input_dim``,
                ``hidden_dim`` and ``num_layers``; the whole object is also
                handed to ``get_position_embedding`` and to each
                ``MiMoELayer``.
        """
        super().__init__(config)
        # Size of the last dimension of the router residual tensor.
        self.num_experts = config.granularity * config.expansion_ratio

        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_dim))
        self.projection_layer = nn.Linear(config.input_dim, config.hidden_dim)
        self.position_embedding = get_position_embedding(config)
        self.layers = nn.ModuleList(
            [MiMoELayer(config) for _ in range(config.num_layers)]
        )

    def forward(
        self,
        x: Tensor
    ) -> MiMoEOutput:
        """Run the full layer stack on a batch of input features.

        Args:
            x: Rank-3 tensor ``(B, T, input_dim)``; the last dimension is fed
                to the input projection.

        Returns:
            A ``MiMoEOutput`` holding the final feature tensor, the per-layer
            ``scores``, ``router_residual`` (exposed as ``router_logits``)
            and ``buffer_ratio`` values, plus the mean of the per-layer
            ``routing_rate``.
        """
        B, T, _ = x.shape
        device = x.device

        cls_token = self.cls_token.expand(B, -1, -1).to(device)
        # Allocate directly on the target device instead of on the host
        # followed by a copy.  T + 1 accounts for the prepended CLS token.
        router_residual = torch.zeros(B, T + 1, self.num_experts, device=device)

        feature = self.projection_layer(x)
        feature = torch.cat((cls_token, feature), dim=1)
        feature = self.position_embedding(feature)

        all_outputs = []
        for layer in self.layers:
            out = layer(feature, router_residual)
            # Each layer consumes the previous layer's feature and residual.
            feature, router_residual = out.feature, out.router_residual
            all_outputs.append(out)

        return MiMoEOutput(
            last_hidden_state=feature,
            router_scores=[o.scores for o in all_outputs],
            router_logits=[o.router_residual for o in all_outputs],
            buffer_ratios=[o.buffer_ratio for o in all_outputs],
            routing_rate=mean(o.routing_rate for o in all_outputs),
        )

    def set_training_mode(
        self,
        router_finetune: str,  # [frozen, trainable, cls_token]
        buffer: str,  # [topk, static_buffer, threshold_buffer]
        buffer_ratio: float = 1.0,  # used for static_buffer
        topk_threshold: float = 1.0,  # used for threshold_buffer
    ):
        """Freeze the model, re-enable the requested parameters, pick a buffer.

        Args:
            router_finetune: ``"frozen"`` keeps every parameter frozen,
                ``"cls_token"`` unfreezes only the CLS token, ``"trainable"``
                unfreezes the router projections of every layer.
            buffer: ``"topk"``, ``"static_buffer"`` or ``"threshold_buffer"``.
            buffer_ratio: Stored on the config as ``buffer_ratio`` when a
                static/threshold buffer is selected.
            topk_threshold: Stored on the config as ``topk_threshold`` when a
                static/threshold buffer is selected.

        Raises:
            ValueError: If ``router_finetune`` or ``buffer`` is not one of the
                accepted values.  Raised before any state is modified.
        """
        # Validate first so a typo cannot leave the model silently frozen
        # with whatever buffers happened to be installed before.
        self._check_choice(
            "router_finetune", router_finetune, self._ROUTER_FINETUNE_MODES
        )
        self._check_choice("buffer", buffer, self._BUFFER_MODES)

        self.requires_grad_(False)

        # --- Router finetuning mode ---
        print(f"🔧 Setting router to *{router_finetune}* mode.")
        if router_finetune == "cls_token":
            self.cls_token.requires_grad = True
        elif router_finetune == "trainable":
            self._enable_trainable_router()
        # "frozen": nothing to re-enable.

        # --- Routing strategy ---
        print(f"📦 Using *{buffer}* routing strategy.")
        self._apply_buffer(buffer, buffer_ratio, topk_threshold)

        total_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total_parameters = sum(p.numel() for p in self.parameters())
        print(f"✅ Total Trainable Parameters after finetuning mode: {total_trainable} / {total_parameters}")

    @staticmethod
    def _check_choice(name: str, value: str, allowed: tuple) -> None:
        """Raise ``ValueError`` unless ``value`` is one of ``allowed``."""
        if value not in allowed:
            raise ValueError(
                f"Unknown {name} {value!r}; expected one of {list(allowed)}."
            )

    def _apply_buffer(self, buffer: str, buffer_ratio: float, topk_threshold: float):
        """Install the buffer named by an already-validated ``buffer`` string."""
        if buffer == "topk":
            self._enable_topk()
        else:
            self._enable_buffer(buffer, buffer_ratio, topk_threshold)

    def _enable_topk(self):
        """Give every layer a fresh ``TopKBuffer`` built from the model config."""
        # NOTE(review): unlike ``_enable_buffer`` this leaves
        # ``layer.use_buffer`` and ``config.buffer`` untouched — confirm
        # whether switching back from a static/threshold buffer should
        # reset them.
        for layer in self.layers:
            layer.buffer = TopKBuffer(self.config)

    def _enable_trainable_router(self):
        """Unfreeze each layer's router projection (and residual projection)."""
        for layer in self.layers:
            router = layer.router
            for param in router.proj.parameters():
                param.requires_grad = True
            # The residual projection only exists on some routers.
            if router.use_router_residual and router.residual_proj is not None:
                for param in router.residual_proj.parameters():
                    param.requires_grad = True

    def _enable_buffer(
        self,
        buffer: str,  # [static_buffer, threshold_buffer]
        buffer_ratio: float = 1.0,  # used for static_buffer
        topk_threshold: float = 1.0,  # used for threshold_buffer
    ):
        """Switch every layer to a ``StaticBuffer`` or ``ThresholdBuffer``.

        The choice and both numeric settings are written onto ``self.config``
        before the buffers are constructed from it, and ``use_buffer`` is set
        to ``True`` on every layer.

        Raises:
            ValueError: If ``buffer`` is neither ``"static_buffer"`` nor
                ``"threshold_buffer"``.
        """
        buffer_classes = {
            "static_buffer": StaticBuffer,
            "threshold_buffer": ThresholdBuffer,
        }
        # Reject unknown names instead of silently falling back to
        # ThresholdBuffer, and do so before the config is mutated.
        self._check_choice("buffer", buffer, self._DYNAMIC_BUFFER_MODES)
        buffer_cls = buffer_classes[buffer]

        config = self.config
        config.buffer = buffer
        config.buffer_ratio = buffer_ratio
        config.topk_threshold = topk_threshold
        for layer in self.layers:
            layer.use_buffer = True
            layer.buffer = buffer_cls(config)

    def set_inference_mode(
        self,
        buffer: str = "static_buffer",  # [topk, static_buffer, threshold_buffer]
        buffer_ratio: float = 1.0,  # used for static_buffer
        topk_threshold: float = 1.0,  # used for threshold_buffer
    ):
        """Freeze every parameter and install the requested buffer.

        ``"topk"`` is handled the same way as in ``set_training_mode`` so the
        two entry points agree on what each buffer name means.  This method
        does not call ``eval()``; it only changes ``requires_grad`` and the
        per-layer buffers.

        Raises:
            ValueError: If ``buffer`` is not one of the accepted values.
                Raised before any state is modified.
        """
        self._check_choice("buffer", buffer, self._BUFFER_MODES)

        self.requires_grad_(False)
        self._apply_buffer(buffer, buffer_ratio, topk_threshold)