import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from utilities.helpers import MyDataset


class MoEClassifier(nn.Module):
    """Mixture-of-experts head over the concatenation of two inputs.

    Every expert is a ``Linear -> ReLU -> Linear`` stack mapping
    ``input_dim`` to ``output_dim`` through ``hidden_dim`` units. A gating
    network (a single ``Linear`` followed by a softmax over dim 1) yields
    one weight per expert, and the module output is the gate-weighted sum
    of the expert outputs.

    Args:
        input_dim: Width of the concatenated ``[x1, x2]`` feature vector.
        output_dim: Width of each expert's output (and of the result).
        num_experts: How many experts to instantiate.
        hidden_dim: Hidden width inside each expert.
    """

    def __init__(self, input_dim, output_dim, num_experts=4, hidden_dim=128):
        super().__init__()
        self.num_experts = num_experts

        # Experts are created one by one, in index order, before the gate.
        expert_bank = nn.ModuleList()
        for _ in range(num_experts):
            layers = [
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, output_dim),
            ]
            expert_bank.append(nn.Sequential(*layers))
        self.experts = expert_bank

        gate_projection = nn.Linear(input_dim, num_experts)
        self.gating_network = nn.Sequential(gate_projection, nn.Softmax(dim=1))

    def forward(self, x1, x2):
        """Return the gate-weighted combination of all expert outputs.

        Args:
            x1: First input tensor.
            x2: Second input tensor; joined to ``x1`` along dim 1, so the
                combined width has to equal the ``input_dim`` given at
                construction.
                NOTE(review): the dim=1 concat/softmax suggests 2-D
                ``(batch, features)`` inputs -- confirm with callers.

        Returns:
            Tensor obtained by summing, over the expert axis (dim 1), each
            expert's output scaled by its gate weight.
        """
        features = torch.cat((x1, x2), dim=1)

        # Run every expert on the same features and line them up on dim 1.
        per_expert = [expert(features) for expert in self.experts]
        stacked = torch.stack(per_expert, dim=1)

        # A trailing singleton axis lets the gate weights broadcast across
        # each expert's output features.
        weights = self.gating_network(features)
        weighted = stacked * weights.unsqueeze(-1)
        return weighted.sum(dim=1)


class MoE(nn.Module):
    """Batch-wise driver around a single ``MoEClassifier``.

    Args:
        num_layers: Forwarded to the classifier as its number of experts.
        in_channels: The classifier is built with an input width of
            ``in_channels * 2`` and an output width of ``in_channels``.
    """

    def __init__(self, num_layers, in_channels):
        super().__init__()
        self.classifier = MoEClassifier(
            input_dim=in_channels * 2,
            output_dim=in_channels,
            num_experts=num_layers,
        )

    def forward(self, device, data, llm, batch_size):
        """Run the classifier over ``data`` in shuffled mini-batches.

        Args:
            device: Target device every batch tensor is moved to.
            data: Passed straight to ``MyDataset`` as its first argument.
            llm: Passed to ``MyDataset`` through its ``emd2`` keyword.
            batch_size: Mini-batch size handed to the ``DataLoader``.

        Returns:
            Dict with ``'y_pred'`` (classifier outputs) and ``'y_true'``
            (the third element of each batch), each concatenated along
            dim 0. Batches are shuffled, so rows do not follow dataset
            order, but the two tensors stay aligned with each other.
        """
        # NOTE(review): each dataset item is unpacked as a 3-tuple
        # (x1, x2, y) -- confirm against MyDataset.
        dataset = MyDataset(data, emb=None, emd2=llm)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

        predictions = []
        targets = []
        for first, second, labels in loader:
            first = first.to(device)
            second = second.to(device)
            labels = labels.to(device)
            predictions.append(self.classifier(first, second))
            targets.append(labels)

        return {
            'y_pred': torch.cat(predictions, dim=0),
            'y_true': torch.cat(targets, dim=0),
        }
