import torch.nn.functional as F
from torch import nn
import torch


class MoELayer(nn.Module):
    """Mixture-of-experts layer with top-k softmax gating over linear experts.

    Every expert is evaluated densely on every input vector. A linear gate
    produces a softmax over experts; the ``top_k`` most probable experts are
    selected per input vector and their outputs are combined with the selected
    gate probabilities, renormalized to sum to 1.

    Args:
        input_dim: Size of the last dimension of the input.
        output_dim: Size of the last dimension of the output.
        num_experts: Number of ``nn.Linear`` experts.
        top_k: Number of experts mixed per input vector; must satisfy
            ``1 <= top_k <= num_experts``.

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``.
    """

    def __init__(self, input_dim, output_dim, num_experts, top_k):
        super(MoELayer, self).__init__()
        # Fail fast: top_k == 0 would silently yield all-zero outputs, and
        # top_k > num_experts would only error inside forward().
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        self.experts = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_experts)])
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Mix the top-k experts for each input vector.

        Args:
            x: Tensor of shape ``(..., input_dim)``. Any number of leading
                dimensions is accepted (including none).

        Returns:
            Tensor of shape ``(..., output_dim)`` with the same leading
            dimensions as ``x``.
        """
        leading_shape = x.shape[:-1]
        # Collapse all leading dims so the gather/sum below always operate on a
        # fixed (N, num_experts, output_dim) layout; restored at the end.
        flat_x = x.reshape(-1, x.size(-1))

        gate_probs = F.softmax(self.gate(flat_x), dim=-1)  # (N, E)
        top_k_gate_probs, top_k_indices = torch.topk(gate_probs, self.top_k, dim=-1)  # (N, k)
        # Renormalize so the selected experts' weights sum to 1 per row.
        top_k_gate_probs = top_k_gate_probs / top_k_gate_probs.sum(dim=-1, keepdim=True)

        # Dense evaluation: every expert runs on every input vector.
        expert_outputs = torch.stack([expert(flat_x) for expert in self.experts], dim=1)  # (N, E, out)
        gather_index = top_k_indices.unsqueeze(-1).expand(-1, -1, expert_outputs.size(-1))
        selected_expert_outputs = torch.gather(expert_outputs, 1, gather_index)  # (N, k, out)
        output = (selected_expert_outputs * top_k_gate_probs.unsqueeze(-1)).sum(dim=1)  # (N, out)
        return output.reshape(*leading_shape, output.size(-1))


class CNNWithMoE(nn.Module):
    """Three-block CNN classifier with a mixture-of-experts hidden layer.

    Each conv block is ``conv3x3 (padding=1) -> ReLU -> 2x2 max-pool``. The
    convs preserve spatial size and each pool halves it, so the network
    downsamples by 8 overall. ``fc1`` is built for a ``128 x 4 x 4`` feature
    map, which corresponds to 3-channel 32x32 inputs.

    Args:
        num_classes: Number of output logits.
        num_experts: Number of experts in the ``MoELayer``.
        topk: Number of experts mixed per sample by the ``MoELayer``.
        dropout_rate: Drop probability for both dropout modules.
    """

    def __init__(self, num_classes=100, num_experts=5, topk=2, dropout_rate=0.05):
        super(CNNWithMoE, self).__init__()
        self.topk = topk
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # NOTE(review): dropout1 is constructed but never applied in forward();
        # kept so existing attribute access / module listings are unchanged.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.moe = MoELayer(256, 256, num_experts, topk)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Keep the batch dimension explicit. view(-1, 128 * 4 * 4) would
        # silently fold extra spatial positions into the batch for inputs that
        # are not 32x32. This way a wrong size fails loudly at fc1.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        o = self.moe(x)
        x = self.fc2(o)
        return x


class CNNWithMoE_tiny(nn.Module):
    """Variant of the MoE CNN whose ``fc1`` is sized for 64x64 inputs.

    Each conv block is ``conv3x3 (padding=1) -> ReLU -> 2x2 max-pool``. The
    convs preserve spatial size and each pool halves it, so the network
    downsamples by 8 overall. ``fc1`` is built for a ``128 x 8 x 8`` feature
    map, which corresponds to 3-channel 64x64 inputs.

    Args:
        num_classes: Number of output logits.
        num_experts: Number of experts in the ``MoELayer``.
        topk: Number of experts mixed per sample by the ``MoELayer``.
        dropout_rate: Drop probability for both dropout modules.
    """

    def __init__(self, num_classes=100, num_experts=5, topk=2, dropout_rate=0.05):
        super(CNNWithMoE_tiny, self).__init__()
        self.topk = topk
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)

        # NOTE(review): dropout1 is constructed but never applied in forward();
        # kept so existing attribute access / module listings are unchanged.
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc1 = nn.Linear(128 * 8 * 8, 256)
        self.moe = MoELayer(256, 256, num_experts, topk)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Keep the batch dimension explicit. view(-1, 128 * 8 * 8) would
        # silently reshape mis-sized inputs into a different batch size. This
        # way a wrong size fails loudly at fc1.
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        o = self.moe(x)
        x = self.fc2(o)
        return x




class CNN_1(nn.Module):  # for homo. exp.
    """LeNet-style CNN: two 5x5 conv/pool stages, then three linear layers.

    The convolutions are unpadded and each pool halves the spatial size, so a
    32x32 input shrinks 32 -> 28 -> 14 -> 10 -> 5. That 5x5 map is the spatial
    size ``fc1`` is built for.

    Args:
        in_channels: Channels of the input image.
        n_kernels: Filters in the first conv; the second conv uses twice as many.
        out_dim: Number of output logits.
    """

    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super().__init__()
        doubled = 2 * n_kernels
        # Module creation order is deliberate: it fixes the RNG draw order of
        # parameter initialization and the state_dict key order.
        self.conv1 = nn.Conv2d(in_channels, n_kernels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, doubled, 5)
        self.fc1 = nn.Linear(doubled * 5 * 5, 2000)
        self.fc2 = nn.Linear(2000, 500)
        self.fc3 = nn.Linear(500, out_dim)

    def forward(self, x):
        """Return logits of shape ``(batch, out_dim)``."""
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv_layer(x)))
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        features = F.relu(self.fc2(hidden))
        return self.fc3(features)


class CNN_tiny(nn.Module):  # for homo. exp.
    """LeNet-style CNN with ``fc1`` sized for 64x64 inputs.

    The convolutions are unpadded and each pool halves the spatial size, so a
    64x64 input shrinks 64 -> 60 -> 30 -> 26 -> 13. That 13x13 map is the
    spatial size ``fc1`` is built for.

    Args:
        in_channels: Channels of the input image.
        n_kernels: Filters in the first conv; the second conv uses twice as many.
        out_dim: Number of output logits.
    """

    def __init__(self, in_channels=3, n_kernels=16, out_dim=10):
        super().__init__()
        doubled = 2 * n_kernels
        # Module creation order is deliberate: it fixes the RNG draw order of
        # parameter initialization and the state_dict key order.
        self.conv1 = nn.Conv2d(in_channels, n_kernels, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(n_kernels, doubled, 5)
        self.fc1 = nn.Linear(doubled * 13 * 13, 2000)
        self.fc2 = nn.Linear(2000, 500)
        self.fc3 = nn.Linear(500, out_dim)

    def forward(self, x):
        """Return logits of shape ``(batch, out_dim)``."""
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv_layer(x)))
        flat = x.view(x.size(0), -1)
        hidden = F.relu(self.fc1(flat))
        features = F.relu(self.fc2(hidden))
        return self.fc3(features)
