from torch import nn
from torch import Tensor
from transformers.activations import ACT2FN

from models.config import MiMoEConfig


class Expert(nn.Module):
    """Feed-forward block: linear up-projection, activation, linear down-projection.

    The inner width is ``int(4 * hidden_dim / granularity)``, so a larger
    ``granularity`` yields a narrower block.

    NOTE(review): going by the class and config names this is presumably one
    expert of a mixture-of-experts layer; routing is not handled here.
    """

    def __init__(self, config: MiMoEConfig):
        """Build the two projections and look up the activation.

        Args:
            config: Supplies ``hidden_dim``, ``granularity`` and
                ``expert_act_fn`` (a key into ``ACT2FN``).
        """
        super().__init__()

        hidden_dim = config.hidden_dim
        granularity = config.granularity
        self.hidden_dim = hidden_dim
        self.granularity = granularity
        # True division followed by int(): a fractional width is truncated,
        # not rounded.
        inner_dim = int(4 * hidden_dim / granularity)
        self.expert_dim = inner_dim

        # Assignment order is kept as up_proj -> act_fn -> down_proj so that
        # submodule registration (and parameter initialisation order) is stable.
        self.up_proj = nn.Linear(hidden_dim, inner_dim)
        self.act_fn = ACT2FN[config.expert_act_fn]
        self.down_proj = nn.Linear(inner_dim, hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        """Apply ``down_proj(act_fn(up_proj(x)))``.

        Args:
            x: Input tensor; it is fed straight to ``up_proj``, so its last
                dimension must be ``hidden_dim``.

        Returns:
            The output of ``down_proj``, whose last dimension is ``hidden_dim``.
        """
        widened = self.up_proj(x)
        activated = self.act_fn(widened)
        return self.down_proj(activated)