import torch
import torch.nn as nn
import torch.nn.functional as F
import pdb

# Note: This is a simplified version of the communication balance loss.
# For the complete implementation (proper token-device mapping,
# device-limited routing, and more efficient calculations),
# please contact the author.
class Expert(nn.Module):
    """Position-wise feed-forward network used as a single expert.

    Two linear transformations with a ReLU activation in between::

        FFN(x) = max(0, x W1 + b1) W2 + b2

    Args:
        d_model: size of the last dimension of the input (e.g., 512).
        d_expert: hidden width of the expert (e.g., 256).
        d_out: size of the last dimension of the output.
    """

    def __init__(self, d_model, d_expert, d_out):
        super().__init__()
        self.d_model = d_model
        self.d_expert = d_expert
        self.d_out = d_out

        # Linear transformation y = xW + b
        self.fc1 = nn.Linear(self.d_model, self.d_expert, bias=True)
        self.fc2 = nn.Linear(self.d_expert, self.d_out, bias=True)

        # Xavier/Glorot initialisation of both weight matrices; the biases
        # keep the default nn.Linear initialisation.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)

    def forward(self, input):
        """Apply the feed-forward network along the last dimension.

        Args:
            input: tensor whose last dimension equals ``d_model``. Any number
                of leading dimensions is accepted (e.g. ``[batch, seq, d_model]``).

        Returns:
            Tensor with the same leading dimensions and last dimension ``d_out``.

        Raises:
            AssertionError: if the last dimension of ``input`` is not ``d_model``.
        """
        # nn.Linear only acts on the last dimension, so that is the only one
        # that has to match; leading dimensions are passed through unchanged.
        assert self.d_model == input.size(-1), "d_model must be the same dimension as the input"
        # max(0, xW_1 + b_1)W_2 + b_2
        return self.fc2(F.relu(self.fc1(input)))

class MixtureOfExperts(nn.Module):
    r"""Mixture of Experts layer in the style of DeepSeek.

    Conceptually::

        MoE(x) = x + \sum_i Expert^s_i(x) + \sum_i gate_i(x; K) * Expert^r_i(x)

    Note that ``forward`` returns only the two expert sums; the residual
    ``x`` is NOT added here and has to be added by the caller if wanted.

    Args:
        d_model: embedding dimension (e.g., 512).
        d_expert: expert hidden dimension (e.g., 216).
        d_out: output dimension of every expert.
        K: number of routed experts selected per token (top-K gate).
        N_s: number of shared experts.
        N_r: number of routed experts.
        alpha1: hyper-parameter; expert-level balance factor.
        alpha2: hyper-parameter; device-level balance factor.
        alpha3: hyper-parameter; communication balance factor.
        D: number of devices the routed experts are partitioned over.
        M: number of devices for Device-Limited Routing.
    """

    def __init__(self, d_model, d_expert, d_out, K, N_s, N_r, alpha1, alpha2, alpha3, D=4, M=3):
        super().__init__()

        assert D < N_r, "Number of partitions needs to be less than number of routed experts"
        assert M <= D, "Number of deviced for Device-Limited Routing needs to be less than number of total device"

        self.d_model = d_model
        self.d_expert = d_expert
        self.d_out = d_out

        self.K = K
        self.N_s = N_s
        self.N_r = N_r
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3

        self.D = D  # number of devices available
        self.M = M  # for Device-Limited Routing

        # Shared experts see every token; routed experts are weighted by the gate.
        self.shared_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert, self.d_out)
            for _ in range(N_s)
        ])

        self.routed_experts = nn.ModuleList([
            Expert(self.d_model, self.d_expert, self.d_out)
            for _ in range(N_r)
        ])

        # Centroids: learnable parameters, one vector per routed expert.
        self.expert_centroids = nn.Parameter(
            torch.randn(N_r, d_model)  # [num_routed_experts, d_model]
        )
        nn.init.xavier_uniform_(self.expert_centroids)

    def forward(self, input, input_gate=None, rdkit=None):
        """Run the shared and routed experts and compute the balance losses.

        Args:
            input: tensor of shape ``[batch, seq, d_model]`` fed to the experts.
            input_gate: optional tensor used to compute the routing
                similarities instead of ``input``; its last dimension must be
                ``d_model``. Defaults to ``input``.
            rdkit: optional extra feature tensor whose similarity to the
                centroids is added to the routing logits; its last dimension
                must be ``d_model`` and the result must broadcast against
                ``[batch, seq, N_r]``. Skipped when ``None``.

        Returns:
            A 7-tuple ``(output, middle_expert, expert_loss, device_loss,
            commu_loss, affinity, indexes)`` where

            * ``output`` is ``shared_output + routed_output``, ``[batch, seq, d_out]``;
            * ``middle_expert`` is every routed expert's raw output
              concatenated along dim 1, ``[batch, N_r * seq, d_out]``;
            * the three losses are scalar tensors already scaled by
              ``alpha1`` / ``alpha2`` / ``alpha3``;
            * ``affinity`` is the softmax over routed experts, ``[batch, seq, N_r]``;
            * ``indexes`` is the top-K expert indices with all size-1
              dimensions squeezed out.

        Raises:
            AssertionError: if the last dimension of ``input`` is not ``d_model``.
        """
        if input_gate is None:
            input_gate = input

        # check input and first FF layer dimension matching
        batch_size, seq_length, d_input = input.size()
        assert self.d_model == d_input, "d_model must be the same dimension as the input"

        # Allocate with the real output width so d_out > d_model also works.
        shared_output = input.new_zeros(batch_size, seq_length, self.d_out)
        for expert in self.shared_experts:
            shared_output += expert(input)  # [batch, seq, d_out]

        # Similarity between tokens and expert centroids -> [batch, seq, N_r]
        centroids_t = self.expert_centroids.transpose(0, 1)
        self.similarities = torch.matmul(input_gate, centroids_t)

        # The extra rdkit term is optional; the default call must not crash.
        if rdkit is not None:
            self.similarities_rdkit = torch.matmul(rdkit, centroids_t)
            self.similarities = self.similarities + self.similarities_rdkit
        else:
            self.similarities_rdkit = None

        affinity = F.softmax(self.similarities, dim=-1)  # [batch, seq, N_r]

        # Top-K gate: keep the K largest affinities, renormalise them with a
        # softmax, and scatter them back into a dense [batch, seq, N_r] tensor.
        values, indexes = torch.topk(affinity, self.K)
        values = F.softmax(values, dim=-1)
        gate = torch.zeros_like(affinity).scatter_(2, indexes, values)
        # kept for testing
        self.last_gate = gate

        # Evaluate every routed expert exactly once and reuse the result for
        # both the gated sum and the concatenated raw outputs.
        all_experts = [expert(input) for expert in self.routed_experts]
        routed_output = input.new_zeros(batch_size, seq_length, self.d_out)
        for i, expert_out in enumerate(all_experts):
            routed_output = routed_output + gate[:, :, i].unsqueeze(-1) * expert_out

        middle_expert = torch.cat(all_experts, dim=1)

        ## Auxiliary losses for load balance
        # T is the number of tokens (batch * seq). The max() guard only
        # avoids a division by zero for an empty batch, where all losses are 0.
        T = max(batch_size * seq_length, 1)

        # Expert-level balance loss.
        # selected[i] = number of tokens routed to expert i (cast so that the
        # dot products below work for any floating dtype).
        selected = torch.count_nonzero(gate, (0, 1)).to(affinity.dtype)  # [N_r]
        f = self.N_r / (self.K * T) * selected
        P = affinity.sum((0, 1)) / T
        expert_loss = self.alpha1 * torch.matmul(f, P)

        # Device-level balance loss: experts are split into D contiguous groups.
        # torch.stack (instead of torch.tensor) keeps the result on the input
        # device and keeps P1 attached to the autograd graph.
        f1 = torch.stack([part.mean() for part in torch.tensor_split(f, self.D)])
        P1 = torch.stack([part.sum() for part in torch.tensor_split(P, self.D)])
        device_loss = self.alpha2 * torch.matmul(f1, P1)

        # Communication balance loss (simplified): counts expert selections
        # per device group rather than distinct tokens sent to the device.
        per_device = torch.stack([part.sum() for part in torch.tensor_split(selected, self.D)])
        f2 = self.D / (self.M * T) * per_device
        P2 = P1
        commu_loss = self.alpha3 * torch.matmul(f2, P2)

        # The residual `input +` is intentionally not added here.
        return shared_output + routed_output, middle_expert, expert_loss, device_loss, commu_loss, affinity, indexes.squeeze()
