from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F

from fmoe.gates.base_gate import BaseGate

# Names exported by ``from <this module> import *``: the two gate classes below.
__all__ = [
    "CustomGlobalAdjGraphGate_Balance_SMoE_PerEpochUpdate",
    "CustomGlobalAdjGraphGate_Balance_XMoE_PerEpochUpdate"
]

def normalized_adj(adj):
    """Scale ``adj`` on both sides by the inverse square root of its row sums.

    Returns ``D^{-1/2} @ adj @ D^{-1/2}``, where ``D`` is the diagonal matrix
    built from the sums of ``adj`` along ``dim=1``. A row sum of zero is
    replaced by 1 so that the scaling factor stays finite.
    """
    row_sums = adj.sum(dim=1)
    # Isolated nodes (zero row sum) get degree 1 to avoid 0 ** -0.5 = inf.
    safe_degree = row_sums.masked_fill(row_sums == 0, 1)
    scale = torch.diag(safe_degree.pow(-0.5))
    return scale @ adj @ scale

def softmax_with_temperature(x, temp, **kwargs):
    """Apply a softmax along ``dim=0`` to ``x`` divided by ``temp``.

    Extra keyword arguments are forwarded to ``F.softmax``.
    """
    scaled = x / temp
    return F.softmax(scaled, dim=0, **kwargs)

def softmax_adaptive_temp(x, **kwargs):
    """Softmax along ``dim=0`` with a per-column temperature.

    The temperature of each column is that column's sum divided by the
    number of rows; a temperature of exactly zero is replaced by 1. Extra
    keyword arguments are forwarded to ``F.softmax``.
    """
    col_mean = x.sum(dim=0) / x.shape[0]
    # A zero temperature would divide by zero, so fall back to 1 there.
    temperature = torch.where(col_mean == 0, torch.ones_like(col_mean), col_mean)
    return F.softmax(x / temperature, dim=0, **kwargs)

def pairwise_combinations(x):
    """Expand each row of ``x`` into all unordered pairs of its entries.

    ``x`` has shape ``[n, k]``. For every row, each pair of column positions
    ``(i, j)`` with ``i < j`` yields one output row ``[x[r, i], x[r, j]]``.
    The result has shape ``[n * c, 2]`` with ``c = k * (k - 1) / 2``; the
    pairs of one input row are contiguous in the output.
    """
    num_cols = x.shape[1]
    # Column-position pairs (i, j) with i < j, shape [c, 2].
    col_pairs = torch.combinations(torch.arange(num_cols, device=x.device), r=2)
    left_cols, right_cols = col_pairs[:, 0], col_pairs[:, 1]
    # [n, c] for each side, stacked on a new last axis -> [n, c, 2].
    paired = torch.stack((x[:, left_cols], x[:, right_cols]), dim=2)
    return paired.view(-1, 2)

class CustomGlobalAdjGraphGate_Balance_SMoE_PerEpochUpdate(BaseGate):
    """Top-k linear gate that tracks an expert co-selection adjacency matrix.

    Routing logits come from a single ``nn.Linear``. While the module is in
    training mode, every forward pass counts which experts were selected
    together for the same token and accumulates the counts in the ``new_adj``
    buffer. ``update_adj`` normalizes those counts and folds them into the
    ``adj`` buffer, which ``forward`` mixes into the gate scores unless
    ``is_warmup`` is set.

    Args:
        d_model: Input feature size of the gating linear layer.
        num_expert: Forwarded to ``BaseGate``.
        world_size: Forwarded to ``BaseGate`` (presumably
            ``self.tot_expert = num_expert * world_size`` -- verify against
            ``BaseGate``).
        top_k: Number of experts selected per token.
        g_blance: If true, ``forward`` also computes a load-balance loss and
            stores it in ``self.loss``.
        threshold: Stored on the instance; not read by this class.
        sym: If true, the per-batch co-selection counts are symmetrized.
        norm_type: Normalization applied to the accumulated counts in
            ``update_adj``: ``"l1"``, ``"l2"``, ``"softmax"``,
            ``"normalized"`` or ``"softmax_adaptive_temp"``.
        alpha: Weight of the identity matrix when mixing ``adj`` into the
            scores (``adj`` gets ``1 - alpha``).
        beta: Weight of the previous ``adj`` in the moving average computed
            by ``update_adj``.
        gamma: Accepted for signature compatibility; unused.
        layerth: Stored on the instance; not read by this class.
        softmax_temp: Temperature used when ``norm_type == "softmax"``.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported values.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2, g_blance=False,
                 threshold=0.5, sym=False, norm_type="", alpha=0.9, beta=0.9, gamma=0.9, layerth=0,
                 softmax_temp=1.0):
        super().__init__(num_expert, world_size)
        self.gate = nn.Linear(d_model, self.tot_expert)
        self.top_k = top_k
        self.dense_moe_flag = False
        self.g_blance = g_blance
        self.loss = None
        self.threshold = threshold
        self.sym = sym
        self.alpha = alpha
        self.beta = beta
        # adj: adjacency used for score mixing; new_adj: raw co-selection
        # counts gathered since the last update_adj() call.
        self.register_buffer('adj', torch.zeros((self.tot_expert, self.tot_expert)))
        self.register_buffer('new_adj', torch.zeros((self.tot_expert, self.tot_expert)))
        self.layerth = layerth
        if norm_type == "l1":
            self.adj_norm = partial(F.normalize, p=1, dim=0)
        elif norm_type == "l2":
            self.adj_norm = partial(F.normalize, p=2, dim=0)
        elif norm_type == "softmax":
            self.adj_norm = partial(softmax_with_temperature, temp=softmax_temp)
        elif norm_type == "normalized":
            self.adj_norm = normalized_adj
        elif norm_type == "softmax_adaptive_temp":
            self.adj_norm = softmax_adaptive_temp
        else:
            raise ValueError(f"Invalid norm type: {norm_type}")

    def update_adj(self):
        """Fold the accumulated counts into ``adj`` and reset ``new_adj``.

        Meant to be called at the end of each epoch. While ``adj`` is still
        (numerically) all zeros it is replaced by the normalized counts;
        afterwards it becomes ``beta * adj + (1 - beta) * normalized counts``.
        """
        normalized_counts = self.adj_norm(self.new_adj)
        if self.adj.sum() <= 1e-3:
            # First update: nothing to average with yet.
            self.adj = normalized_counts
        else:
            self.adj = self.beta * self.adj + (1 - self.beta) * normalized_counts
        self.new_adj.zero_()

    def set_load_balance(self, gate, gate_top_k_idx):
        """Compute the load-balance loss and store it in ``self.loss``.

        The loss is ``tot_expert * sum(fraction_expert * prob_expert)`` where
        ``fraction_expert`` is the share of valid (``> -1``) selections that
        went to each expert and ``prob_expert`` is the summed softmax
        probability of each expert divided by the number of valid selections.
        """
        score = F.softmax(gate, dim=-1)
        valid_idx = gate_top_k_idx[gate_top_k_idx > -1]
        fraction_expert = (
            torch.scatter_add(
                torch.zeros(self.tot_expert, device=valid_idx.device),
                0,
                valid_idx,
                torch.ones_like(valid_idx, dtype=torch.float),
            )
            / valid_idx.numel()
        )
        prob_expert = score.sum(dim=0) / valid_idx.numel()

        loss = (fraction_expert * prob_expert).sum() * self.tot_expert
        self.loss = loss

    def _accumulate_adj(self, gate_top_k_idx):
        """Add this batch's expert co-selection counts to ``new_adj``.

        ``gate_top_k_idx`` is only read; the pair tensor is kept in a local
        variable so the indices returned by ``forward`` stay ``[n, top_k]``.
        """
        if self.top_k < 2:
            # A single selected expert forms no pair, and a one-row index
            # tensor cannot address the 2-D adjacency matrix.
            return
        with torch.no_grad():
            pair_idx = gate_top_k_idx
            if self.top_k > 2:
                # All unordered pairs among the selected experts: [n * c, 2].
                pair_idx = pairwise_combinations(pair_idx)
            indices = pair_idx.t()  # [2, num_pairs]
            values = torch.ones(indices.size(1), device=indices.device)
            # Duplicate index pairs are summed when converting to dense, so
            # each entry counts how often the pair occurred in this batch.
            batch_adj = torch.sparse_coo_tensor(
                indices=indices,
                values=values,
                size=(self.tot_expert, self.tot_expert)
            ).to_dense()
            if self.sym:
                batch_adj = batch_adj + batch_adj.t()
            self.new_adj += batch_adj

    def forward(self, inp, prev_gate_top_k_idx=None, prev_adj=None, return_all_scores=False, is_warmup=False, **kwargs):
        """Select the top-k experts per token and compute their scores.

        Args:
            inp: Input to the gating linear layer (assumed 2-D,
                ``[tokens, d_model]`` -- TODO confirm with callers).
            prev_gate_top_k_idx: Unused.
            prev_adj: Unused.
            return_all_scores: If true, also return the raw gate logits.
            is_warmup: If true, skip mixing ``adj`` into the scores.
            **kwargs: Ignored.

        Returns:
            ``(gate_top_k_idx, gate_score)`` or, with ``return_all_scores``,
            ``(gate_top_k_idx, gate_score, gate)``.
        """
        gate = self.gate(inp)

        # sorted=False: the order of the k selected experts is unspecified.
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=False
        )  # [.. x top_k]

        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)  # [tokens, top_k]
        gate_score = F.softmax(gate_top_k_val, dim=-1)

        if self.training:
            self._accumulate_adj(gate_top_k_idx)

        # adj is all zeros before the first update_adj(), hence the warmup flag.
        if not is_warmup:
            # Blend the self-effect (identity) with the learned adjacency.
            full_adj = (
                self.alpha * torch.eye(self.tot_expert, device=gate_top_k_val.device)
                + (1 - self.alpha) * self.adj
            )
            # NOTE(review): gate_score has top_k columns while full_adj is
            # tot_expert x tot_expert, so this product is only shape-valid
            # when top_k == tot_expert. Confirm the intended mixing.
            gate_score = gate_score @ full_adj

        if self.g_blance:
            self.set_load_balance(gate, gate_top_k_idx)

        if return_all_scores:
            return gate_top_k_idx, gate_score, gate
        return gate_top_k_idx, gate_score

class CustomGlobalAdjGraphGate_Balance_XMoE_PerEpochUpdate(BaseGate):
    """Embedding-based top-k gate that tracks an expert co-selection matrix.

    Routing logits are the dot product between a linearly reduced input
    (8 dimensions) and L2-normalized expert embeddings. While the module is
    in training mode, every forward pass counts which experts were selected
    together for the same token and accumulates the counts in the
    ``new_adj`` buffer. ``update_adj`` normalizes those counts and folds them
    into the ``adj`` buffer, which ``forward`` mixes into the gate scores
    unless ``is_warmup`` is set.

    Args:
        d_model: Input feature size.
        num_expert: Forwarded to ``BaseGate``; also the number of expert
            embeddings.
        world_size: Forwarded to ``BaseGate`` (presumably
            ``self.tot_expert = num_expert * world_size`` -- verify against
            ``BaseGate``).
        top_k: Number of experts selected per token.
        g_blance: If true, ``forward`` also computes a load-balance loss and
            stores it in ``self.loss``.
        threshold: Stored on the instance; not read by this class.
        sym: If true, the per-batch co-selection counts are symmetrized and
            the top-k indices are returned unsorted.
        norm_type: Normalization applied to the accumulated counts in
            ``update_adj``: ``"l1"``, ``"l2"``, ``"softmax"``,
            ``"normalized"`` or ``"softmax_adaptive_temp"``.
        alpha: Weight of the identity matrix when mixing ``adj`` into the
            scores (``adj`` gets ``1 - alpha``).
        beta: Weight of the previous ``adj`` in the moving average computed
            by ``update_adj``.
        gamma: Accepted for signature compatibility; unused.
        layerth: Stored on the instance; not read by this class.
        softmax_temp: Temperature used when ``norm_type == "softmax"``.

    Raises:
        ValueError: If ``norm_type`` is not one of the supported values.
    """

    def __init__(self, d_model, num_expert, world_size, top_k=2, g_blance=False,
                 threshold=0.5, sym=False, norm_type="", alpha=0.9, beta=0.9, gamma=0.9, layerth=0,
                 softmax_temp=1.0):
        super().__init__(num_expert, world_size)
        # Not used by cal_raw_gate(); kept so the module's parameters
        # (and therefore its state dict) stay unchanged.
        self.gate = nn.Linear(d_model, self.tot_expert)
        self.top_k = top_k
        self.dense_moe_flag = False
        self.g_blance = g_blance
        self.loss = 0.0
        self.threshold = threshold
        self.sym = sym
        self.alpha = alpha
        self.beta = beta
        # adj: adjacency used for score mixing; new_adj: raw co-selection
        # counts gathered since the last update_adj() call.
        self.register_buffer('adj', torch.zeros((self.tot_expert, self.tot_expert)))
        self.register_buffer('new_adj', torch.zeros((self.tot_expert, self.tot_expert)))
        self.layerth = layerth
        if norm_type == "l1":
            self.adj_norm = partial(F.normalize, p=1, dim=0)
        elif norm_type == "l2":
            self.adj_norm = partial(F.normalize, p=2, dim=0)
        elif norm_type == "softmax":
            self.adj_norm = partial(softmax_with_temperature, temp=softmax_temp)
        elif norm_type == "normalized":
            self.adj_norm = normalized_adj
        elif norm_type == "softmax_adaptive_temp":
            self.adj_norm = softmax_adaptive_temp
        else:
            raise ValueError(f"Invalid norm type: {norm_type}")

        # NOTE(review): the embeddings are sized by num_expert whereas adj and
        # the load-balance loss use self.tot_expert; the two only agree when
        # tot_expert == num_expert -- verify for world_size > 1.
        expert_embeddings = torch.empty(num_expert, 8)
        torch.nn.init.orthogonal_(expert_embeddings, gain=0.32)
        self.register_parameter(
            "expert_embeddings", torch.nn.Parameter(expert_embeddings)
        )

        self.inp_reduction = torch.nn.Linear(d_model, 8, bias=False)

    def set_load_balance(self, gate, gate_top_k_idx):
        """Compute the load-balance loss and store it in ``self.loss``.

        Same formula as ``tot_expert * sum(fraction_expert * prob_expert)``,
        with the expert probabilities taken from ``softmax(gate / 0.3)``.
        ``fraction_expert`` is the share of valid (``> -1``) selections that
        went to each expert.
        """
        score = F.softmax(gate / 0.3, dim=-1)
        valid_idx = gate_top_k_idx[gate_top_k_idx > -1]
        fraction_expert = (
            torch.scatter_add(
                torch.zeros(self.tot_expert, device=valid_idx.device),
                0,
                valid_idx,
                torch.ones_like(valid_idx, dtype=torch.float),
            )
            / valid_idx.numel()
        )
        prob_expert = score.sum(dim=0) / valid_idx.numel()

        loss = (fraction_expert * prob_expert).sum() * self.tot_expert
        self.loss = loss

    def cal_raw_gate(self, inp):
        """Return routing logits for ``inp`` (must be 2-D, see ``_cosine``).

        Side effect: the expert embeddings are rescaled in place to an L2
        norm of 1.5 per row on every call (outside of autograd).
        """
        reduced_inp = self.inp_reduction(inp)
        with torch.no_grad():
            expert_embeddings_norm = self.expert_embeddings.norm(
                p=2.0, dim=1, keepdim=True
            )
            self.expert_embeddings.mul_(1.5 / expert_embeddings_norm)

        gate = self._cosine(reduced_inp, self.expert_embeddings)
        gate = self._make_finite(gate)
        return gate

    def update_adj(self):
        """Fold the accumulated counts into ``adj`` and reset ``new_adj``.

        Meant to be called at the end of each epoch. While ``adj`` is still
        (numerically) all zeros it is replaced by the normalized counts;
        afterwards it becomes ``beta * adj + (1 - beta) * normalized counts``.
        """
        normalized_counts = self.adj_norm(self.new_adj)
        if self.adj.sum() <= 1e-3:
            # First update: nothing to average with yet.
            self.adj = normalized_counts
        else:
            self.adj = self.beta * self.adj + (1 - self.beta) * normalized_counts
        self.new_adj.zero_()

    def _accumulate_adj(self, gate_top_k_idx):
        """Add this batch's expert co-selection counts to ``new_adj``.

        ``gate_top_k_idx`` is only read, so the indices returned by
        ``forward`` keep their ``[n, top_k]`` shape.
        """
        if self.top_k < 2:
            # A single selected expert forms no pair, and a one-row index
            # tensor cannot address the 2-D adjacency matrix.
            return
        with torch.no_grad():
            pair_idx = gate_top_k_idx
            if self.top_k > 2:
                # All pairs among the selected experts: [n * c, 2]. Pairs keep
                # the column order of gate_top_k_idx.
                pair_idx = pairwise_combinations(pair_idx)
            indices = pair_idx.t()  # [2, num_pairs]
            values = torch.ones(indices.size(1), device=indices.device)
            # Duplicate index pairs are summed when converting to dense, so
            # each entry counts how often the pair occurred in this batch.
            batch_adj = torch.sparse_coo_tensor(
                indices=indices,
                values=values,
                size=(self.tot_expert, self.tot_expert)
            ).to_dense()
            if self.sym:
                batch_adj = batch_adj + batch_adj.t()
            self.new_adj += batch_adj

    def forward(self, inp, return_all_scores=False, is_warmup=False, **kwargs):
        """Select the top-k experts per token and compute their scores.

        Args:
            inp: 2-D input, ``[tokens, d_model]``.
            return_all_scores: If true, also return the raw gate logits.
            is_warmup: If true, skip mixing ``adj`` into the scores.
            **kwargs: Ignored.

        Returns:
            ``(gate_top_k_idx, gate_score)`` or, with ``return_all_scores``,
            ``(gate_top_k_idx, gate_score, gate)``.
        """
        gate = self.cal_raw_gate(inp)
        # Without symmetrization the indices are requested in descending
        # score order, so each recorded pair runs from the higher-scored to
        # the lower-scored expert.
        gate_top_k_val, gate_top_k_idx = torch.topk(
            gate, k=self.top_k, dim=-1, largest=True, sorted=not self.sym
        )  # [.. x top_k]

        gate_top_k_val = gate_top_k_val.view(-1, self.top_k)  # [tokens, top_k]
        gate_score = F.softmax(gate_top_k_val, dim=-1)

        if self.training:
            self._accumulate_adj(gate_top_k_idx)

        # adj is all zeros before the first update_adj(), hence the warmup flag.
        if not is_warmup:
            # Blend the self-effect (identity) with the learned adjacency.
            full_adj = (
                self.alpha * torch.eye(self.tot_expert, device=gate.device)
                + (1 - self.alpha) * self.adj
            )
            # NOTE(review): gate_score has top_k columns while full_adj is
            # tot_expert x tot_expert, so this product is only shape-valid
            # when top_k == tot_expert. Confirm the intended mixing.
            gate_score = gate_score @ full_adj

        if self.g_blance:
            self.set_load_balance(gate, gate_top_k_idx)

        if return_all_scores:
            return gate_top_k_idx, gate_score, gate
        return gate_top_k_idx, gate_score

    def _cosine(self, mat1, mat2, eps=1e-4):
        """Return ``mat1 @ normalize(mat2).T`` computed in float32.

        Only ``mat2`` is L2-normalized along ``dim=1``; ``mat1`` is used as
        is, so the result is a cosine similarity scaled by the row norms of
        ``mat1``. The output is cast back to the dtype of ``mat1``.
        """
        assert mat1.dim() == 2
        assert mat2.dim() == 2
        mat2 = F.normalize(mat2.float(), p=2.0, dim=1, eps=eps)
        return mat1.float().matmul(mat2.transpose(0, 1)).type_as(mat1)

    def _make_finite(self, scores):
        """Replace non-finite entries of ``scores`` in place and return it.

        NaN/inf entries are overwritten with the smallest finite entry.
        Assumes at least one entry is finite.
        """
        ok = scores.isfinite()
        if not ok.all():
            # NaNs here can break the assignment algorithm
            scores[~ok] = scores[ok].min()
        return scores
