import torch.nn.functional as F
from torch import nn
import torch


class MoELayer(nn.Module):
    """Mixture-of-experts layer with top-k softmax gating (dense compute).

    Each expert is an ``nn.Linear(input_dim, output_dim)``. A linear gate
    scores the experts, softmax turns the scores into probabilities, and the
    ``top_k`` largest probabilities are renormalised to sum to 1. The output
    is the probability-weighted sum of the selected experts' outputs.

    Every expert is evaluated for every input and the unselected outputs are
    discarded afterwards, so compute scales with ``num_experts``, not
    ``top_k``.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of each expert's output.
        num_experts: Number of experts.
        top_k: Number of experts mixed per input vector.

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``.
    """

    def __init__(self, input_dim, output_dim, num_experts, top_k):
        super(MoELayer, self).__init__()
        # Fail fast here rather than inside torch.topk during forward().
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList(
            [nn.Linear(input_dim, output_dim) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Mix the top-k experts for each input vector.

        Args:
            x: Tensor of shape ``(..., input_dim)``. Any number of leading
                dimensions is supported, e.g. ``(batch, input_dim)`` or
                ``(batch, seq, input_dim)``.

        Returns:
            Tensor of shape ``(..., output_dim)``.
        """
        gate_probs = F.softmax(self.gate(x), dim=-1)  # (..., num_experts)
        top_k_gate_probs, top_k_indices = torch.topk(
            gate_probs, self.top_k, dim=-1
        )  # both (..., top_k)
        # Renormalise so the selected experts' weights sum to 1.
        top_k_gate_probs = top_k_gate_probs / top_k_gate_probs.sum(
            dim=-1, keepdim=True
        )
        # Stack on the second-to-last dim so arbitrary leading dims survive:
        # (..., num_experts, output_dim).
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        out_dim = expert_outputs.size(-1)
        # Broadcast each selected index across the output features:
        # (..., top_k, output_dim).
        index = top_k_indices.unsqueeze(-1).expand(*top_k_indices.shape, out_dim)
        selected_expert_outputs = torch.gather(expert_outputs, -2, index)
        return (selected_expert_outputs * top_k_gate_probs.unsqueeze(-1)).sum(dim=-2)


class CNNWithoutMoE(nn.Module):
    """Three-stage CNN classifier with a plain linear head (no MoE).

    Architecture: three ``conv3x3 -> ReLU -> maxpool2x2`` stages
    (3 -> 32 -> 64 -> 128 channels), then
    ``fc1 -> ReLU -> dropout -> fc2 -> fc3``.

    Args:
        num_classes: Number of output scores produced per sample.
        dropout_rate: Drop probability for the dropout layers.
    """

    def __init__(self, num_classes=100, dropout_rate=0.05):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # NOTE(review): dropout1 is constructed but never used in forward();
        # kept so the module's attribute set is unchanged.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        # 128 channels x 4 x 4: three 2x2 pools reduce e.g. a 32x32 input to 4x4.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Compute raw (un-normalised) class scores.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)``, where three 2x2 max-pools
                must reduce ``H x W`` to ``4 x 4`` (e.g. 32x32).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Keep the batch dim explicit. With view(-1, 2048), a mis-sized feature
        # map would be silently regrouped into a different number of samples.
        # Here fc1 raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        # NOTE(review): no activation between fc2 and fc3, so the two layers
        # compose into a single affine map; confirm this is intentional.
        features = self.fc2(x)
        return self.fc3(features)

class CNNWithMoE_last1(nn.Module):
    """Three-stage CNN classifier whose final layer is a mixture of experts.

    Same trunk as the plain CNN: three ``conv3x3 -> ReLU -> maxpool2x2`` stages
    (3 -> 32 -> 64 -> 128 channels), then ``fc1 -> ReLU -> dropout -> fc2``.
    The last linear layer is replaced by ``MoELayer(256 -> num_classes)``.

    Args:
        num_classes: Number of output scores produced per sample.
        num_experts: Number of experts in the final MoE layer.
        topk: Number of experts mixed per sample.
        dropout_rate: Drop probability for the dropout layers.
    """

    def __init__(self, num_classes=100, num_experts=5, topk=2, dropout_rate=0.05):
        # Zero-argument super() avoids naming the wrong class (the previous
        # super(CNNWithoutMoE, self) raised TypeError on construction).
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # NOTE(review): dropout1 is constructed but never used in forward();
        # kept so the module's attribute set is unchanged.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        # 128 channels x 4 x 4: three 2x2 pools reduce e.g. a 32x32 input to 4x4.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 256)
        # MoELayer's signature is (input_dim, output_dim, num_experts, top_k).
        self.moe = MoELayer(256, num_classes, num_experts, topk)

    def forward(self, x):
        """Compute raw (un-normalised) class scores.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)``, where three 2x2 max-pools
                must reduce ``H x W`` to ``4 x 4`` (e.g. 32x32).

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Keep the batch dim explicit. With view(-1, 2048), a mis-sized feature
        # map would be silently regrouped into a different number of samples.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        # NOTE(review): fc2's output feeds the MoE without an activation; confirm intentional.
        features = self.fc2(x)
        return self.moe(features)