import torch
import torch.nn as nn
from einops import rearrange, repeat, reduce, pack, unpack
import torch.nn.functional as F
 
from .register import register_moe
from .moe import MoeLayer

def softmax_with_temperature(x, temp, **kwargs):
    """Return the softmax of ``x / temp`` taken along dimension 0.

    Any extra keyword arguments (e.g. ``dtype``) are forwarded unchanged to
    ``torch.nn.functional.softmax``.
    """
    scaled = x / temp
    return F.softmax(scaled, dim=0, **kwargs)

def softmax_adaptive_temp(x, **kwargs):
    """Softmax along dimension 0 with a per-column, data-derived temperature.

    The temperature of each column is the sum over dimension 0 divided by the
    size of dimension 0; a temperature of exactly zero is replaced by 1 to
    avoid dividing by zero. Extra keyword arguments are forwarded to
    ``torch.nn.functional.softmax``.
    """
    column_mean = torch.sum(x, dim=0) / x.shape[0]
    # A zero temperature would produce inf/nan; fall back to a neutral 1.
    safe_temp = column_mean.masked_fill(column_mean == 0, 1)
    return F.softmax(x / safe_temp, dim=0, **kwargs)

def normalized_adj(adj):
    """Return ``D^-1/2 @ adj @ D^-1/2`` where ``D`` holds the row sums of ``adj``.

    Rows whose sum is exactly zero are given a degree of 1 so that the
    inverse square root stays finite.
    """
    row_degree = adj.sum(dim=1)
    row_degree = row_degree.masked_fill(row_degree == 0, 1)
    inv_sqrt_degree = torch.diag(row_degree.pow(-0.5))
    return inv_sqrt_degree @ adj @ inv_sqrt_degree

@register_moe("smoe_graphgating")
class SMoeGraphGating(MoeLayer):
    """Sparse MoE layer that smooths top-k routing weights with an expert graph.

    While training, the layer maintains ``adj``: an exponential moving average
    of a normalized expert co-selection matrix. Outside warm-up, the routing
    weights of the experts chosen for each token are mixed with
    ``alpha * I + (1 - alpha) * adj``.

    ``self.gate``, ``self.compute_moe``, ``self.combine_loss`` and the
    ``num_of_experts`` / ``num_selected`` / ``out_embed_dim`` attributes are
    not defined here; they are presumably provided by ``MoeLayer`` -- verify
    against the base class.
    """

    def __init__(self, in_embed_dim=768, out_embed_dim=768, num_of_experts=4, num_selected=2,
                 expert=None, args=None, gate_alpha=0.9, gate_beta=0.9, gate_norm_type="softmax",
                 gate_softmax_temp=1.0, gate_sym=False, **kwargs):
        """Build the layer.

        Args:
            in_embed_dim, out_embed_dim, num_of_experts, num_selected, expert,
                args: forwarded positionally to ``MoeLayer.__init__``.
            gate_alpha: weight of the identity matrix when mixing routing
                weights; ``1 - gate_alpha`` is the weight of ``adj``.
            gate_beta: EMA decay applied to the stored adjacency.
            gate_norm_type: how each batch adjacency is normalized; one of
                ``"l1"``, ``"l2"``, ``"softmax"``, ``"normalized"``,
                ``"softmax_adaptive_temp"``.
            gate_softmax_temp: temperature, used only by ``"softmax"``.
            gate_sym: if true, symmetrize the batch adjacency before
                normalizing it.
            **kwargs: accepted and ignored.

        Raises:
            ValueError: if ``gate_norm_type`` is not one of the known names.
        """
        super().__init__(in_embed_dim, out_embed_dim, num_of_experts, num_selected, expert, args)
        self.alpha = gate_alpha
        self.beta = gate_beta
        self.sym = gate_sym
        # When True, forward() skips the graph mixing of the routing weights.
        self.is_warmuping = False
        self.register_buffer('adj', torch.zeros((num_of_experts, num_of_experts)), persistent=True)

        if gate_norm_type == "l1":
            self.adj_norm = lambda x: F.normalize(x, p=1, dim=0)
        elif gate_norm_type == "l2":
            self.adj_norm = lambda x: F.normalize(x, p=2, dim=0)
        elif gate_norm_type == "softmax":
            self.adj_norm = lambda x: softmax_with_temperature(x, temp=gate_softmax_temp)
        elif gate_norm_type == "normalized":
            self.adj_norm = normalized_adj
        elif gate_norm_type == "softmax_adaptive_temp":
            self.adj_norm = softmax_adaptive_temp
        else:
            raise ValueError(f"Invalid norm type: {gate_norm_type}")

    def update_adjacency(self, selected_experts):
        """Fold the expert co-selections of one batch into the ``adj`` buffer.

        For every token, each pair (higher-ranked expert, lower-ranked expert)
        among its selected experts adds one count at ``[higher, lower]``. With
        ``num_selected == 2`` this is the single (first, second) pair per
        token; other values of ``num_selected`` are handled the same way
        (``num_selected == 1`` yields an all-zero count matrix).

        Args:
            selected_experts: integer tensor of expert ids whose trailing
                dimension has size ``num_selected`` (as returned by
                ``torch.topk``).
        """
        device = selected_experts.device
        ranked = selected_experts.reshape(-1, self.num_selected)

        # All rank pairs (i, j) with i < j inside the top-k selection.
        pair_idx = torch.triu_indices(self.num_selected, self.num_selected, offset=1, device=device)
        rows = ranked[:, pair_idx[0]].reshape(-1)
        cols = ranked[:, pair_idx[1]].reshape(-1)

        batch_adj = torch.zeros(self.num_of_experts, self.num_of_experts, device=device)
        # accumulate=True sums repeated (row, col) pairs instead of overwriting.
        batch_adj.index_put_((rows, cols), torch.ones(rows.numel(), device=device), accumulate=True)

        if self.sym:
            batch_adj = batch_adj + batch_adj.t()
        batch_adj = self.adj_norm(batch_adj)

        # A (near-)zero buffer means no statistics yet: take the batch as-is
        # instead of averaging it with zeros.
        if self.adj.sum() <= 1e-3:
            self.adj = batch_adj
        else:
            self.adj = self.beta * self.adj + (1 - self.beta) * batch_adj

    def topk_expert(self, gate_logits):
        """Select the top ``num_selected`` experts along the last dimension.

        Returns:
            A tuple ``(weights, selected_experts, gate_softmax)`` where
            ``weights`` is a softmax over the selected logits only,
            ``selected_experts`` are their indices, and ``gate_softmax`` is a
            float32 softmax over all gate logits.
        """
        gate_softmax = F.softmax(gate_logits, dim=-1, dtype=torch.float32)
        top_logits, selected_experts = torch.topk(gate_logits, self.num_selected)
        weights = F.softmax(top_logits, dim=-1)

        return weights, selected_experts, gate_softmax

    def _smooth_weights(self, weights, selected_experts):
        """Mix the routing weights of the selected experts through the graph.

        The per-rank weights are scattered to their expert ids, multiplied by
        ``alpha * I + (1 - alpha) * adj`` and read back at the selected ids,
        so the result keeps the shape of ``weights``. Multiplying the per-rank
        weights by the (experts x experts) matrix directly is only
        shape-valid when ``num_selected == num_of_experts`` and would index
        the graph by rank instead of by expert id.
        """
        mixing = self.alpha * torch.eye(self.num_of_experts, device=weights.device) + \
            (1 - self.alpha) * self.adj
        mixing = mixing.to(weights.dtype)

        dense = torch.zeros(*weights.shape[:-1], self.num_of_experts,
                            device=weights.device, dtype=weights.dtype)
        dense = dense.scatter(-1, selected_experts, weights)
        return (dense @ mixing).gather(-1, selected_experts)

    def forward(self, x, return_id_experts = False, is_vision = False):
        """Route ``x`` through the experts using graph-smoothed weights.

        Args:
            x: input tensor; its first two dimensions size the output
                (presumably ``(batch, tokens, in_embed_dim)`` -- confirm
                against callers).
            return_id_experts: if true, the selected expert ids are returned
                in the third slot of the result, otherwise ``None``.
            is_vision: unused in this method.

        Returns:
            ``(output, auxiliary_loss, selected_experts_or_None, balance_loss)``.
        """
        gate_logits = self.gate(x)
        weights, selected_experts, gate_softmax = self.topk_expert(gate_logits)

        # Keep the adjacency buffer on the same device as the gate output.
        if gate_logits.device != self.adj.device:
            self.adj = self.adj.to(gate_logits.device)

        if self.training:
            # Adjacency statistics must not take part in autograd.
            with torch.no_grad():
                self.update_adjacency(selected_experts)

        if not self.is_warmuping:
            weights = self._smooth_weights(weights, selected_experts)

        output = torch.zeros(x.shape[0], x.shape[1], self.out_embed_dim, device=x.device, dtype=x.dtype)
        output = self.compute_moe(selected_experts, weights, output, x)
        auxiliary_loss, balance_loss = self.combine_loss(selected_experts, gate_softmax, gate_logits)

        expert_ids = selected_experts if return_id_experts else None
        return output, auxiliary_loss, expert_ids, balance_loss
