import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
# Define the expert model
class LinearExperts(nn.Module):
    """Bank of independent linear experts evaluated in a single einsum.

    Every expert ``e`` owns a weight matrix ``w[e]`` of shape
    ``(input_size, output_size)`` and a bias row ``bias[e]`` of shape
    ``(1, output_size)``; both are initialised from a standard normal.

    Args:
        num_experts: Number of experts in the bank.
        input_size: Number of features of each (flattened) input row.
        output_size: Number of outputs produced by each expert.
    """

    def __init__(self, num_experts=5, input_size=28*28,  output_size=10):
        super().__init__()
        # Experts parameters (weights and biases)
        self.w = nn.Parameter(torch.randn(num_experts, input_size, output_size))
        self.bias = nn.Parameter(torch.randn(num_experts, 1, output_size))

    def forward(self, x):
        """Apply every expert to the L2-normalised rows of ``x``.

        Args:
            x: 2-D floating-point tensor of shape ``(batch, input_size)``.

        Returns:
            Tensor of shape ``(num_experts, batch, output_size)``.
        """
        # Normalizing the x, since we use the probit loss function.
        # The norm is clamped away from zero so that an all-zero row stays
        # all-zero instead of becoming NaN through 0 * (1 / 0).
        norm = torch.norm(x, p=2, dim=1)
        norm = norm.clamp_min(torch.finfo(norm.dtype).eps)
        x = x * (1 / norm).unsqueeze(1)

        # Computing the final outputs: one matmul per expert, then the
        # per-expert bias is broadcast over the batch dimension.
        out = torch.einsum("bd, edo -> ebo", x, self.w)
        return out + self.bias


# Define SVM experts
class SVMExperts(nn.Module):
    """Bank of linear experts applied on top of a kernel feature map.

    With ``kernel='rbf'`` every input row is mapped to its RBF similarities
    against all rows of ``train_data_x`` (one feature per training row).
    With ``kernel='linear'`` the input is used unchanged. The resulting
    features are L2-normalised per row and fed to ``num_experts`` linear
    heads.

    Args:
        train_data_x: Reference samples used by the RBF kernel, one sample
            per row. Not used by the linear kernel.
        kernel: ``'rbf'`` or ``'linear'``.
        gamma_init: Initial value of the RBF bandwidth ``gamma``.
        train_gamma: Whether ``gamma`` receives gradients.
        num_experts: Number of experts in the bank.
        input_size: Number of input features; sets the expert weight size
            for the linear kernel.
        output_size: Number of outputs produced by each expert.

    Raises:
        ValueError: If ``kernel`` is not supported, or if the RBF kernel is
            requested without ``train_data_x``.
    """

    def __init__(self, train_data_x, kernel='rbf',
                 gamma_init=1.0, train_gamma=True, num_experts=5, input_size=28*28,  output_size=10):
        super().__init__()
        if kernel not in ('linear', 'rbf'):
            raise ValueError(
                f"Unsupported kernel {kernel!r}; expected 'linear' or 'rbf'."
            )
        if kernel == 'rbf' and train_data_x is None:
            raise ValueError("The 'rbf' kernel requires train_data_x.")

        # Keep the reference samples as a non-persistent buffer so they
        # follow the module across .to()/.cuda() without adding a key to
        # state_dict(). Anything that is not a plain tensor is stored as-is.
        if isinstance(train_data_x, torch.Tensor) and not isinstance(train_data_x, nn.Parameter):
            self.register_buffer('_train_data_x', train_data_x, persistent=False)
        else:
            self._train_data_x = train_data_x

        if kernel == 'linear':
            self._kernel = self.linear
            # The linear kernel leaves the input untouched, so the experts
            # consume input_size features.
            self._num_c = input_size
        else:
            self._kernel = self.rbf
            # One kernel feature per reference sample.
            self._num_c = train_data_x.size(0)
            self._gamma = nn.Parameter(
                torch.tensor([float(gamma_init)], dtype=torch.float32),
                requires_grad=train_gamma,
            )

        # Experts parameters
        self.w = nn.Parameter(torch.randn(num_experts, self._num_c, output_size))
        self.bias = nn.Parameter(torch.randn(num_experts, 1, output_size))

    def rbf(self, x):
        """Return ``exp(-gamma * ||x_i - t_j||^2)`` for every input/reference pair.

        Args:
            x: Tensor of shape ``(batch, features)``.

        Returns:
            Tensor of shape ``(batch, num_reference_samples)``.
        """
        # Broadcasting yields the same pairwise differences as tiling the
        # reference set once per batch row, without the extra copy.
        diff = x[:, None] - self._train_data_x[None]
        return torch.exp(-self._gamma * diff.pow(2).sum(dim=2))

    @staticmethod
    def linear(x):
        """Linear kernel: return the input unchanged."""
        return x

    def forward(self, x):
        """Map ``x`` through the kernel, normalise, and apply every expert.

        Args:
            x: Tensor of shape ``(batch, features)``.

        Returns:
            Tensor of shape ``(num_experts, batch, output_size)``.
        """
        x = self._kernel(x)
        # Clamp the norm away from zero: a zero input row (linear kernel) or
        # kernel values that all underflow to 0 (RBF, distant inputs) would
        # otherwise produce NaN through 0 * (1 / 0).
        norm = torch.norm(x, p=2, dim=1)
        norm = norm.clamp_min(torch.finfo(norm.dtype).eps)
        x = x * (1 / norm).unsqueeze(1)
        out = torch.einsum("bd, edo -> ebo", x, self.w)
        return out + self.bias

    


# Define the gating network
class GatingNetwork(nn.Module):
    """Three-layer MLP that outputs a softmax distribution over experts.

    Architecture: ``Linear -> ReLU -> Linear -> Tanh -> Linear -> softmax``.
    When ``ldp_condition`` is set, the hidden activation is scaled to unit
    L2 norm and the logits are rescaled by ``epsilon / (4 * ||W_fc3||)``
    before the softmax (the source refers to "Theorem 2.2" for this).

    Args:
        num_experts: Number of gating outputs.
        input_size: Number of input features.
        hidden_size: Width of both hidden layers.
        epsilon: Scale applied to the logits when ``ldp_condition`` is set.
        ldp_condition: Enable the normalisation/rescaling described above.
    """

    def __init__(self, num_experts, input_size, hidden_size=64, epsilon = 1, ldp_condition = False):
        super().__init__()
        self.epsilon = epsilon
        self.ldp_condition = ldp_condition

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.tanh = nn.Tanh()
        self.fc3 = nn.Linear(hidden_size, num_experts)
        # Small constant used to keep the divisions in forward() finite.
        self.eps = torch.finfo(torch.float32).eps

    def forward(self, x):
        """Return gating weights of shape ``(batch, num_experts)``.

        Args:
            x: Tensor of shape ``(batch, input_size)``.

        Returns:
            Softmax over the expert dimension; every row sums to 1.
        """
        hidden = self.tanh(self.fc2(self.relu(self.fc1(x))))
        if self.ldp_condition:
            # normalize the output to get the LDP condition, according to Theorem 2.2.
            # The norm is clamped so an all-zero activation stays zero
            # instead of turning into NaN.
            norm = torch.norm(hidden, p=2, dim=1).clamp_min(self.eps)
            hidden = hidden * (1 / norm).unsqueeze(1)
        logits = self.fc3(hidden)
        if self.ldp_condition:
            # normalize the output to get the LDP condition, according to Theorem 2.2.
            # torch.norm without dim is the 2-norm of the flattened weight
            # (Frobenius norm).
            w_norm = torch.norm(self.fc3.weight, p=2)
            logits = logits * self.epsilon / (4 * w_norm + self.eps)
        return F.softmax(logits, dim=1)



# Define the Mixtures of Experts model
class MoEModel(nn.Module):
    """Mixture-of-experts model: a bank of experts plus a gating network.

    Args:
        input_size: Number of input features.
        output_size: Number of outputs produced by each expert.
        num_experts: Number of experts (and gating outputs).
        epsilon: Forwarded to the gating network.
        ldp_condition: Forwarded to the gating network.
        expert_type: ``'linear'`` for ``LinearExperts`` or ``'svm'`` for
            ``SVMExperts`` with an RBF kernel.
        train_data_x: Reference samples for the SVM experts; required when
            ``expert_type='svm'``.

    Raises:
        ValueError: If ``expert_type`` is not supported, or if
            ``expert_type='svm'`` is requested without ``train_data_x``.
    """

    def __init__(self, input_size=28*28, output_size=10, num_experts=200, epsilon = 1, ldp_condition = False, expert_type='linear',train_data_x=None):
        super().__init__()
        # Fail fast instead of leaving self.experts undefined until forward().
        if expert_type not in ('linear', 'svm'):
            raise ValueError(
                f"Unsupported expert_type {expert_type!r}; expected 'linear' or 'svm'."
            )
        if expert_type == 'svm' and train_data_x is None:
            raise ValueError("expert_type='svm' requires train_data_x.")

        self.epsilon = epsilon
        self.num_experts = num_experts
        self.ldp_condition = ldp_condition
        if expert_type == 'linear':
            self.experts = LinearExperts(num_experts, input_size, output_size)
        else:
            self.experts = SVMExperts(train_data_x, kernel='rbf', num_experts=num_experts, input_size=input_size,  output_size=output_size)
        self.gating_network = GatingNetwork(num_experts, input_size, epsilon=self.epsilon, ldp_condition=self.ldp_condition)

    def forward(self, x):
        """Run the experts and the gating network on the same input.

        The two results are returned separately; no mixing is done here.

        Returns:
            Tuple ``(expert_outputs, gating_weights)``.
        """
        expert_outputs = self.experts(x)
        gating_weights = self.gating_network(x)
        return expert_outputs, gating_weights

    def get_experts_norm(self):
        """Return squared L2 norms of the expert parameters as a flat vector.

        The norms are taken over dim 1 of ``experts.w`` and ``experts.bias``
        (the feature dimension), squared, summed, and flattened, so the
        result holds one value per (expert, output) pair rather than one
        value per expert.
        """
        weight_sq = torch.norm(self.experts.w, dim=1) ** 2
        bias_sq = torch.norm(self.experts.bias, dim=1) ** 2
        return (weight_sq + bias_sq).flatten()


