'''
Adapted from the original implementation of MixtralSparseMoeBlock in Hugging Face's transformers repository.
Note that this is only an illustration of how to implement our method in MixtralSparseMoeBlock;
you should adapt it to your own training setup (e.g., gradient accumulation, distributed training, model architecture, etc.).
(This does not mean that we use the Mixtral architecture in our experiments.)
'''



class MixtralSparseMoeBlock(nn.Module):
    """
    Adapted from the original implementation of MixtralSparseMoeBlock in Hugging Face's transformers repository.

    Adds a per-expert routing bias to the router's softmax scores, used ONLY for
    choosing the top-k experts. The combination weights are still taken from the
    unbiased scores. In training mode, after each forward pass every expert's bias
    is moved by ``bias_u``:
    - experts that received more assignments than the mean are pushed down;
    - experts that received fewer assignments than the mean are pushed up.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_size
        self.ffn_dim = config.intermediate_size
        self.num_experts = config.num_local_experts
        self.top_k = config.num_experts_per_tok
        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)])
        self.jitter_noise = config.router_jitter_noise

        # ======== Added for our method =========
        # Per-expert routing bias. It is updated manually in `forward` (training
        # mode only) and must not be touched by the optimizer, hence requires_grad=False.
        self.bias = nn.Parameter(torch.zeros(self.num_experts), requires_grad=False)
        self.bias_u = config.bias_u  # bias update rate
        # =======================================

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Route each token to its top-k experts and combine the expert outputs.

        Args:
            hidden_states: tensor of shape ``(batch_size, sequence_length, hidden_dim)``.
                It is not modified in place.

        Returns:
            ``(final_hidden_states, router_logits)``:
            - ``final_hidden_states`` has the same shape as the input.
            - ``router_logits`` has shape ``(batch_size * sequence_length, num_experts)``.

        Side effects:
            In training mode, updates ``self.bias`` in place from this call's expert loads.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        if self.training and self.jitter_noise > 0:
            # Out-of-place on purpose: `*=` would silently modify the caller's tensor.
            hidden_states = hidden_states * torch.empty_like(hidden_states).uniform_(
                1.0 - self.jitter_noise, 1.0 + self.jitter_noise
            )
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits = self.gate(hidden_states)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

        # ======== Adapted for our method =========
        # The bias only influences WHICH experts are selected; the gating weights
        # are gathered from the unbiased softmax scores.
        _, selected_experts = torch.topk(routing_weights + self.bias.unsqueeze(0), self.top_k, dim=-1)
        routing_weights = routing_weights.gather(-1, selected_experts)
        # =======================================

        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        routing_weights = routing_weights.to(hidden_states.dtype)
        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )
        # one_hot gives (tokens, top_k, experts); permuted to (experts, top_k, tokens)
        # so that expert_mask[e] lists which (slot, token) pairs are routed to expert e.
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
        for expert_idx, expert_layer in enumerate(self.experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            current_state = hidden_states[top_x]
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)

        # ======== Added for our method =========
        # Only adapt the bias while training; eval/inference passes must not
        # drift the routing.
        if self.training:
            self._update_bias(selected_experts)
        # =======================================

        return final_hidden_states, router_logits

    def _update_bias(self, selected_experts: torch.Tensor) -> None:
        """Move each expert's routing bias by ``bias_u`` toward a balanced load.

        Args:
            selected_experts: integer tensor of expert indices chosen in this
                forward pass (one entry per token and top-k slot).
        """
        # Note: ci is supposed to be computed on the whole batch. The code here
        # assumes no gradient accumulation / distributed training; adapt it if
        # you use them (e.g. accumulate ci over micro-batches / reduce across ranks).
        with torch.no_grad():
            # bincount returns an integer tensor, and Tensor.mean is undefined
            # for integer dtypes, so convert to float first.
            ci = torch.bincount(selected_experts.flatten(), minlength=self.num_experts).float()
            delta_bias = (ci.mean() - ci).sign()
            self.bias.add_(delta_bias.to(self.bias.dtype), alpha=self.bias_u)