# moe_router.py
# Minimal Mixture-of-Experts (MoE) router that routes context vectors to expert modules.
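# The router produces soft gating weights over the experts; the experts themselves are passed to
# forward() on each call as callables (e.g. nn.Module instances) and are not owned by the router.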

from typing import Callable, List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class MoERouter(nn.Module):
    """
    Router that maps context -> gating weights over experts, then merges
    expert outputs using those weights (soft routing: every expert runs and
    contributes in proportion to its softmax weight).

    The experts are supplied to ``forward`` on each call rather than stored
    on the router, so their parameters are NOT registered as submodules of
    this module (they will not appear in ``router.parameters()``).
    """

    def __init__(self, context_dim: int = 256, n_experts: int = 4, hidden: int = 128):
        """
        Args:
            context_dim: size of the last dimension of the context vectors.
            n_experts: number of experts the gate scores; ``forward`` requires
                exactly this many experts.
            hidden: width of the gate's hidden layer.

        Raises:
            ValueError: if ``n_experts`` is less than 1.
        """
        super().__init__()
        if n_experts < 1:
            raise ValueError(f"n_experts must be >= 1, got {n_experts}")
        self.n_experts = n_experts
        # Two-layer MLP producing one unnormalized score per expert.
        self.gate = nn.Sequential(
            nn.Linear(context_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_experts),
        )

    def forward(
        self,
        context: torch.Tensor,
        experts: List[Callable[[torch.Tensor], torch.Tensor]],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Route ``context`` through every expert and blend the outputs.

        Args:
            context: ``(B, context_dim)``, or ``(context_dim,)`` which is
                treated as a batch of one.
            experts: exactly ``n_experts`` callables/modules, each accepting
                the (batched) context and returning a tensor of shape
                ``(B, *D_out)``; all experts must return the same shape.

        Returns:
            A tuple ``(combined, weights)``:
              * ``combined``: ``(B, *D_out)`` weighted sum of expert outputs.
                For a 1-D context the batch dimension of size 1 is kept.
              * ``weights``: ``(B, n_experts)`` softmax gating weights.

        Raises:
            ValueError: if the number of experts differs from ``n_experts``.
        """
        # Materialize once so generators / ModuleLists both work with len().
        experts = list(experts)
        # Without this check a single expert would silently broadcast against
        # all n_experts weights and produce a wrong (but valid-shaped) result.
        if len(experts) != self.n_experts:
            raise ValueError(
                f"expected {self.n_experts} experts, got {len(experts)}"
            )

        if context.dim() == 1:
            context = context.unsqueeze(0)
        logits = self.gate(context)  # (B, n_experts)
        weights = F.softmax(logits, dim=-1)  # (B, n_experts)

        # (B, n_experts, *D_out); torch.stack requires identical output shapes.
        stacked = torch.stack([expert(context) for expert in experts], dim=1)
        # Append singleton dims so weights broadcast over any trailing D_out.
        w = weights.reshape(*weights.shape, *([1] * (stacked.dim() - 2)))
        combined = (w * stacked).sum(dim=1)  # (B, *D_out)
        return combined, weights
