from typing import Optional, Tuple

import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from models.registry import register_router
from models.config import MiMoEConfig


@register_router("basic_router")
class BasicRouter(nn.Module):
    """Single-linear-layer top-k router for a mixture-of-experts layer.

    Scores ``num_experts = granularity * expansion_ratio`` experts for every
    input vector and keeps the ``granularity`` highest-scoring ones. When
    ``use_router_residual`` is enabled, a learned projection of an externally
    supplied ``router_residual`` tensor is added to the logits before the
    softmax.

    Args:
        config: Model configuration. Reads ``granularity``,
            ``expansion_ratio``, ``hidden_dim``, ``use_router_residual`` and
            ``routing_temperature``.

    Raises:
        ValueError: If ``config.routing_temperature`` is not strictly
            positive. A zero temperature turns the scaled logits into
            inf/NaN, and a negative one inverts the ranking of experts.
    """

    def __init__(
        self, 
        config: MiMoEConfig
    ):
        super().__init__()
        # `not t > 0` also rejects NaN temperatures.
        if not config.routing_temperature > 0:
            raise ValueError(
                f"routing_temperature must be > 0, got {config.routing_temperature!r}"
            )
        self.moe_top_k = config.granularity
        self.use_router_residual = config.use_router_residual
        self.routing_temperature = config.routing_temperature
        self.num_experts = config.granularity * config.expansion_ratio
        self.proj = nn.Linear(config.hidden_dim, self.num_experts, bias=False)
        # Maps a num_experts-wide residual onto the logits; only built when enabled.
        self.residual_proj = nn.Linear(self.num_experts, self.num_experts, bias=False) if self.use_router_residual else None

    def forward(
        self,
        x: Tensor,
        router_residual: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Route ``x`` to its top-k experts.

        Args:
            x: Input activations whose last dimension is ``hidden_dim``
                (e.g. ``[B, T, hidden_dim]``).
            router_residual: Tensor whose last dimension is ``num_experts``.
                It is passed through ``residual_proj`` and added to the
                logits. It is required when ``use_router_residual`` is
                enabled and ignored (may be omitted) otherwise.

        Returns:
            ``(scores_k, indices_k, logits, scores)``:

            - ``scores_k``: the ``moe_top_k`` largest softmax probabilities
              along the last dimension, in descending order.
            - ``indices_k``: the expert indices of those probabilities.
            - ``logits``: raw router logits (residual included), *before*
              temperature scaling; last dimension ``num_experts``.
            - ``scores``: the full temperature-scaled softmax distribution
              over experts; last dimension ``num_experts``.

        Raises:
            ValueError: If ``use_router_residual`` is enabled but
                ``router_residual`` is ``None``.
        """
        logits = self.proj(x)
        if self.use_router_residual:
            if router_residual is None:
                raise ValueError(
                    "router_residual is required when use_router_residual is enabled"
                )
            logits = logits + self.residual_proj(router_residual)
        # Temperature only shapes the probabilities; the returned logits stay unscaled.
        scores = F.softmax(logits / self.routing_temperature, dim=-1)  # [B, T, num_experts]
        scores_k, indices_k = torch.topk(scores, k=self.moe_top_k, dim=-1)
        return scores_k, indices_k, logits, scores


@register_router("2layer_router")
class TwoLayersRouter(BasicRouter):
    """Router whose gating projection is a two-layer MLP instead of one linear map.

    Only ``self.proj`` differs from the parent. It becomes
    ``hidden_dim -> 2 * hidden_dim -> ReLU -> num_experts`` with no biases.
    Top-k selection, temperature, the optional residual projection and
    ``forward`` are all inherited unchanged.
    """

    def __init__(self, config: MiMoEConfig):
        super().__init__(config)
        # The parent already built a single-layer `proj`; it is replaced here.
        # Layers are constructed in the same order as before so seeded
        # parameter initialisation is reproducible.
        in_features = config.hidden_dim
        mid_features = 2 * in_features
        expand = nn.Linear(in_features, mid_features, bias=False)
        activation = nn.ReLU()
        gate = nn.Linear(mid_features, self.num_experts, bias=False)
        # Positional Sequential keeps the `proj.0.*` / `proj.2.*` state_dict keys.
        self.proj = nn.Sequential(expand, activation, gate)