import torch
from models.deepseek_moe import MixtureOfExperts
import pdb

class MLP(torch.nn.Module):
    """Plain feed-forward network: Linear -> Dropout -> activation blocks plus a final Linear.

    The stack always starts with one ``input_feat -> dim_feat`` block and ends
    with a ``dim_feat -> num_tasks`` projection; ``num_layers - 2`` additional
    ``dim_feat -> dim_feat`` blocks are placed in between. Values of
    ``num_layers`` below 2 therefore still yield two Linear layers.

    Args:
        input_feat: Size of the last dimension of the input.
        dim_feat: Width of every hidden layer.
        num_tasks: Size of the last dimension of the output.
        num_layers: Total number of Linear layers requested (see note above).
        dropout: Drop probability handed to every ``torch.nn.Dropout``.
        activation: Module applied after each dropout. The same instance is
            reused at every position inside this model. When ``None`` (the
            default) a fresh ``torch.nn.ReLU`` is created for this model.
    """

    def __init__(self, input_feat, dim_feat, num_tasks, num_layers=3, dropout=0.5, activation=None):
        super().__init__()

        # A module instance used as a default argument is created once at
        # function-definition time and would be shared by every MLP ever built;
        # build it per model instead.
        if activation is None:
            activation = torch.nn.ReLU()

        # Input width of every block that precedes the output projection.
        # A non-positive repeat count yields no extra hidden blocks.
        block_inputs = [input_feat] + [dim_feat] * (num_layers - 2)

        layers = []
        for in_width in block_inputs:
            # Dropout is applied before the activation, as in the original layout.
            layers.extend(
                [
                    torch.nn.Linear(in_width, dim_feat),
                    torch.nn.Dropout(dropout),
                    activation,
                ]
            )
        layers.append(torch.nn.Linear(dim_feat, num_tasks))

        self.model = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return its output."""
        return self.model(x)

class MLPMoE(torch.nn.Module):
    """Mixture-of-experts layer followed by a two-layer head with a single output unit.

    Args:
        input_feat: Passed to the MoE as ``d_model`` and used as the input
            width of the head's first Linear layer.
        dim_feat: Passed to the MoE as both ``d_expert`` and ``d_out``; also
            the hidden width of the head.
        num_tasks: Stored on the instance; not otherwise used in this class.
        num_experts: Number of routed experts (``N_r``) given to the MoE.
        num_heads: Stored on the instance and used to size ``self.query``.
        output_feat: Accepted for interface compatibility; unused here.
    """

    def __init__(self, input_feat, dim_feat, num_tasks, num_experts, num_heads, output_feat):
        super().__init__()
        self.num_heads = num_heads
        self.moe = MixtureOfExperts(
            d_model=input_feat,   # Input dimension
            d_expert=dim_feat,    # Expert hidden dimension
            d_out=dim_feat,
            K=2,                  # Top-K experts per token
            N_s=0,                # Number of shared experts
            N_r=num_experts,      # Number of routed experts
            alpha1=0.01,          # Expert balance factor
            alpha2=0.01,          # Device balance factor
            alpha3=0.01,          # Communication balance factor
            D=1,                  # Number of devices
            M=1,                  # Device limit for routing
        )

        # NOTE(review): the first Linear expects `input_feat` features while the
        # MoE is configured with d_out=dim_feat; presumably this only lines up
        # when input_feat == dim_feat - confirm against MixtureOfExperts.
        self.out = torch.nn.Sequential(
            torch.nn.Linear(input_feat, dim_feat, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(dim_feat, 1, bias=True),
        )

        # Not used by forward(); kept so parameter names and checkpoints are unchanged.
        self.query = torch.nn.Linear(dim_feat, dim_feat * self.num_heads)
        self.num_tasks = num_tasks
        self.dim_feat = dim_feat

    def forward(self, x, rdkit_feat=None):
        """Route ``x`` through the MoE and apply the output head.

        Args:
            x: Input tensor. A 1-D tensor is treated as a single sample and
                given a leading axis of size 1.
            rdkit_feat: Optional tensor forwarded to the MoE as its third
                argument with an axis inserted at position 1. ``None`` is
                passed through unchanged instead of failing on ``.unsqueeze``.
                NOTE(review): whether ``MixtureOfExperts`` accepts ``None``
                here is not visible from this file - confirm.

        Returns:
            Tuple ``(output_final, gate, features, expert_num)`` where
            ``output_final`` is the head output with axis 1 squeezed,
            ``features`` is the MoE output with all size-1 axes removed, and
            ``gate`` / ``expert_num`` are returned by the MoE unchanged.
        """
        x_input = x.unsqueeze(0) if x.dim() == 1 else x

        # NOTE(review): only `x` is promoted to a batch of one; a 1-D
        # `rdkit_feat` is not - confirm callers always pass it batched.
        if rdkit_feat is None:
            rdkit_tokens = None
        else:
            rdkit_tokens = rdkit_feat.unsqueeze(1)

        # The MoE returns seven values; only the first and the last two are used.
        f_moe, _, _, _, _, gate, expert_num = self.moe(
            x_input.unsqueeze(1), x_input.unsqueeze(1), rdkit_tokens
        )
        output_final = self.out(f_moe).squeeze(1)

        # squeeze() without a dim removes every size-1 axis, so a batch of one
        # also loses its batch axis. NOTE(review): confirm callers rely on that.
        return output_final, gate, f_moe.squeeze(), expert_num
