from transformers import PretrainedConfig
from typing import List, Tuple, Optional


class MiMoEConfig(PretrainedConfig):
    """Configuration container for the MiMoE model.

    Stores the hyper-parameters for the model backbone, router, buffer,
    experts and losses as plain attributes.

    Every argument has a default so the class can be instantiated with no
    arguments. ``PretrainedConfig.to_diff_dict`` builds ``self.__class__()``
    to compute the non-default fields. That path is used by ``__repr__``,
    ``to_json_string`` and ``save_pretrained``, so a required argument here
    makes printing or saving the config raise ``TypeError``.

    Args:
        pretrain_loss: Name of the pretraining loss, or ``None``.
        input_dim: Input feature dimension. Defaults to ``None`` only so the
            config is constructible without arguments; presumably the model
            needs a concrete value, so set it explicitly when building one.
        batch_size: Batch size stored on the config.
        max_seq_len: Maximum sequence length.
        num_heads: Number of attention heads.
        dropout_rate: Dropout probability.
        routing_temperature: Temperature stored for the router.
        num_layers: Number of layers.
        hidden_dim: Hidden size; also copied into ``critic_dim``.
        granularity: Expert granularity setting.
        expansion_ratio: Expert expansion ratio.
        position_embedding: One of ``"sinusoidal"``, ``"learned"``, ``"rope"``.
        use_router_residual: Flag stored for the router.
        router: One of ``"basic_router"``, ``"2layer_router"``.
        buffer: One of ``"topk"``, ``"static_buffer"``, ``"threshold_buffer"``.
        buffer_ratio: Used when ``buffer == "static_buffer"``.
        topk_threshold: Used when ``buffer == "threshold_buffer"``.
        expert_act_fn: Name of the expert activation function.
        infonce_temperature: Temperature stored for the InfoNCE loss.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "mimoe"

    def __init__(
        self,
        pretrain_loss: Optional[str] = None,
        input_dim: Optional[int] = None,
        batch_size: int = 128,
        max_seq_len: int = 50,
        num_heads: int = 4,
        dropout_rate: float = 0.1,
        routing_temperature: float = 0.1,
        num_layers: int = 5,
        hidden_dim: int = 300,
        granularity: int = 2,
        expansion_ratio: int = 8,
        position_embedding: str = "rope",  # sinusoidal, learned, rope
        use_router_residual: bool = True,
        router: str = "basic_router",  # basic_router, 2layer_router
        buffer: str = "topk",  # topk, static_buffer, threshold_buffer
        buffer_ratio: float = 1.0,  # used for static_buffer
        topk_threshold: float = 1.0,  # used for threshold_buffer
        expert_act_fn: str = "relu2",
        infonce_temperature: float = 0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # --- Model configuration ---
        self.input_dim = input_dim
        self.batch_size = batch_size
        self.hidden_dim = hidden_dim
        self.max_seq_len = max_seq_len
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dropout_rate = dropout_rate
        self.position_embedding = position_embedding

        # --- Router ---
        self.router = router
        self.routing_temperature = routing_temperature
        self.use_router_residual = use_router_residual

        # --- Buffer ---
        self.buffer = buffer
        self.buffer_ratio = buffer_ratio
        self.topk_threshold = topk_threshold

        # --- Experts ---
        self.granularity = granularity
        self.expansion_ratio = expansion_ratio
        self.expert_act_fn = expert_act_fn

        # --- Losses ---
        # critic_dim is derived, not a constructor argument: it always mirrors
        # hidden_dim. A value reloaded from JSON via **kwargs is overwritten here.
        self.critic_dim = hidden_dim
        self.infonce_temperature = infonce_temperature
        self.pretrain_loss = pretrain_loss