from typing import List

import torch
import torch.nn.functional as F
from torch import nn



class MoeLayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k routing per token.

    Every row of the flattened input is treated as one token. ``gate`` scores
    each token against every expert; the ``num_experts_per_tok`` highest
    scores are kept, renormalised with a softmax, and the token's output is
    the weighted sum of the outputs of those selected experts.

    Args:
        experts: Expert modules. Each is called on a ``(n_tokens, features)``
            tensor; its output is accumulated into a buffer created with
            ``torch.zeros_like`` of the flattened input, so it must return
            ``(n_tokens, features)``.
        gate: Module producing one routing logit per expert for each token
            (presumably ``(n_tokens, len(experts))`` — the topk/where logic
            below relies on the last dim indexing experts).
        moe_args: Object indexed with ``['num_experts_per_tok']`` giving the
            number of experts each token is routed to.
    """

    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Route each token through its top-k experts and mix their outputs.

        Args:
            inputs: Tensor whose last dimension is the feature dimension; all
                leading dimensions (if any) are treated as independent tokens.

        Returns:
            Tensor with the same shape as ``inputs``.
        """
        original_shape = inputs.shape
        # Always flatten to (tokens, features). ``reshape`` (unlike ``view``)
        # also works for non-contiguous inputs, and this covers 1-D inputs too;
        # for an input that is already 2-D it is a no-op.
        flat = inputs.reshape(-1, original_shape[-1])

        gate_logits = self.gate(flat)
        weights, selected_experts = torch.topk(
            gate_logits, self.args['num_experts_per_tok'], dim=-1
        )
        # Renormalise the top-k scores in at least float32 so half-precision
        # logits do not lose accuracy, then cast back to the activation dtype.
        softmax_dtype = torch.promote_types(weights.dtype, torch.float32)
        weights = F.softmax(weights, dim=-1, dtype=softmax_dtype).to(flat.dtype)

        results = torch.zeros_like(flat)
        for expert_id, expert in enumerate(self.experts):
            # token_idx: which tokens chose this expert; slot_idx: at which
            # top-k position they chose it (used to pick the matching weight).
            token_idx, slot_idx = torch.where(selected_experts == expert_id)
            if token_idx.numel() == 0:
                continue  # no tokens routed to this expert
            expert_output = expert(flat[token_idx])
            weighted_output = weights[token_idx, slot_idx, None] * expert_output
            # index_add_ accumulates even if an index repeats, unlike
            # ``results[idx] += ...`` which keeps only one write per index.
            results.index_add_(0, token_idx, weighted_output.to(results.dtype))

        return results.reshape(original_shape)
