import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import ones_, zeros_
from torch_geometric.nn.inits import glorot
from torch_scatter import scatter
from typing import List
from utils import calculate_norm_A


class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each feature vector is divided by its RMS (computed in float32 and cast
    back to the input dtype). It is then scaled by a learnable per-feature
    ``weight`` and, when ``bias`` is true, shifted by a learnable ``offset``.

    Args:
        dim: size of the last (normalised) dimension.
        eps: added to the mean square before the reciprocal square root.
        bias: if true, register and apply a learnable additive ``offset``.
    """

    def __init__(self, dim: int, eps: float = 1e-6, bias=False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        # ``bias`` is kept as a flag; the additive parameter itself is ``offset``.
        self.bias = bias
        if bias:
            self.offset = nn.Parameter(torch.zeros(dim))

    def reset_parameters(self):
        """Restore the gain to ones and, if present, the offset to zeros."""
        ones_(self.weight)
        if self.bias:
            zeros_(self.offset)

    def _norm(self, x):
        """Return ``x`` divided by its root-mean-square along the last dim."""
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        """Normalise ``x`` and apply the learnable gain (and offset)."""
        scaled = self._norm(x.float()).type_as(x) * self.weight
        return scaled + self.offset if self.bias else scaled


class BiAttentionMoeLayer(nn.Module):
    """Mix a shared attention expert with gated "router" attention experts.

    ``forward`` runs ``share_experts[0]`` on ``edge_masks[0]`` and router expert
    ``i`` on ``edge_masks[i + 1]``. It then combines their outputs per node as

        out = alpha * shared_out + (1 - alpha) * sum_i weights[:, i] * router_out_i

    ``alpha`` and ``weights`` are input-dependent (via ``alpha_weights`` /
    ``gate_weights``) only for the first ``num_ori_nodes`` rows. Later rows use
    fixed routing (see ``forward``).

    NOTE(review): ``forward`` fills router weights only in columns 0 and 1, so
    it effectively assumes exactly two router experts. It also calls
    ``len(self.router_experts)``, which fails if ``router_experts`` was empty
    (stored as None).
    """

    def __init__(self, share_experts: List[nn.Module], router_experts: List[nn.Module], gate_weights: nn.Parameter, alpha_weights: nn.Parameter):
        super().__init__()
        # Empty expert lists are stored as None rather than an empty ModuleList.
        if len(share_experts) > 0:
            self.share_experts = nn.ModuleList(share_experts)
        else:
            self.share_experts = None
        if len(router_experts) > 0:
            self.router_experts = nn.ModuleList(router_experts)
        else:
            self.router_experts = None
        # Learnable vectors. forward uses (inputs * param).sum(-1) as a logit:
        # ``gate`` for the 2-way router gate, ``alpha`` for the shared/router mix.
        self.gate = gate_weights
        self.alpha = alpha_weights

        # NOTE(review): not read or updated anywhere else in this class.
        self.router_weights = None

    def reset_parameters(self):
        """Reset every expert and zero the gate/alpha vectors.

        With zeroed vectors both sigmoids in ``forward`` start at 0.5.
        """
        if self.share_experts is not None:
            for expert in self.share_experts:
                expert.reset_parameters()
        if self.router_experts is not None:
            for expert in self.router_experts:
                expert.reset_parameters()
        zeros_(self.gate)
        zeros_(self.alpha)

    def forward(self, inputs, edge_masks, need_weights=False):
        """Combine shared and routed expert outputs for every node.

        Args:
            inputs: node features. Rows are presumably ordered as original nodes,
                then cluster nodes, then the remaining (global) nodes. TODO
                confirm against the data pipeline.
            edge_masks: ``edge_masks[0]`` feeds the shared expert and
                ``edge_masks[i + 1]`` feeds router expert ``i``.
            need_weights: forwarded to each expert. The attention weights the
                experts return are discarded here.

        Returns:
            The mixed node features.
        """
        # Assumes edge_masks[0][0] is an index tensor over original nodes only,
        # whose largest id is num_ori_nodes - 1. TODO confirm.
        num_ori_nodes = edge_masks[0][0].max() + 1
        # Assumes cluster ids directly follow the original node ids in edge_masks[1][0].
        num_cluster = edge_masks[1][0].max() + 1 - num_ori_nodes

        if self.share_experts is not None:
            results, _ = self.share_experts[0](
                inputs, edge_masks[0], need_weights)
        else:
            # NOTE(review): has the width of ``inputs``, which must match the
            # experts' output width for the sums below.
            results = torch.zeros_like(inputs)

        # Per-node routing weights, one column per router expert.
        weights = torch.zeros([inputs.size(0), len(self.router_experts)], device=inputs.device)
        # Original nodes: learned gate, sigmoid for expert 0 and its complement for expert 1.
        weights[:num_ori_nodes, 0] = F.sigmoid((inputs[:num_ori_nodes] * self.gate).sum(dim=-1))
        weights[:num_ori_nodes, 1] = 1. - weights[:num_ori_nodes, 0]

        # Cluster rows are routed only to expert 0; all remaining rows only to expert 1.
        weights[num_ori_nodes:num_ori_nodes+num_cluster, 0] = 1.
        weights[num_ori_nodes+num_cluster:, 1] = 1.

        # Shared-expert share per node. It stays 0 for non-original rows, so
        # those rows receive router output only.
        # NOTE(review): weights/alpha are created with the default dtype, not inputs.dtype.
        alpha = torch.zeros(inputs.size(0), device=inputs.device).unsqueeze(-1)
        alpha[:num_ori_nodes] = F.sigmoid((inputs[:num_ori_nodes] * self.alpha).sum(dim=-1, keepdim=True)) # [N, 1]

        # In-place: scale the shared output, then accumulate the routed outputs.
        results *= alpha
        for i, expert in enumerate(self.router_experts):
            out, _ = expert(inputs, edge_masks[i+1], need_weights)
            results += (1. - alpha) * weights[:,i].unsqueeze(-1) * out

        return results

class DualAttention(torch.nn.Module):
    """Multi-head node aggregation driven by sparse and/or dense attention masks.

    ``agg_type`` selects how edge scores are produced:

    * ``"Trans"``: scaled dot-product scores ``(q / sqrt(head_dim)) . k``.
    * ``"GAT"``: additive scores ``LeakyReLU(<h_i, alpha1> + <h_j, alpha2>)``.
    * ``"GCN"``: fixed edge weights returned by ``calculate_norm_A``. It uses a
      single head and is only supported for sparse (tensor) masks.

    Args:
        in_dim: input feature width.
        h_dim: output width, split across heads (must be divisible by the head count).
        num_heads: number of heads (forced to 1 for ``"GCN"``).
        dropout: dropout probability applied to attention coefficients.
        bias: whether the linear projections use a bias.
        use_cache: ``"GCN"`` only. Compute ``calculate_norm_A`` once and reuse it.
        agg_type: one of ``"Trans"``, ``"GAT"``, ``"GCN"``.

    Raises:
        ValueError: if ``agg_type`` is not a supported value.
    """

    _AGG_TYPES = ("Trans", "GAT", "GCN")

    def __init__(self, in_dim, h_dim, num_heads, dropout=0.1, bias=True, use_cache=False, agg_type="Trans"):
        super(DualAttention, self).__init__()
        if agg_type not in self._AGG_TYPES:
            raise ValueError(
                f"agg_type must be one of {self._AGG_TYPES}, got {agg_type!r}")
        # GCN aggregation uses a single "head" spanning all of h_dim.
        self.n_head = 1 if agg_type == "GCN" else num_heads
        self.in_dim = in_dim
        self.h_dim = h_dim

        self.attn_dropout = nn.Dropout(p=dropout)
        self.agg_type = agg_type
        self.use_cache = use_cache

        if agg_type == "Trans":
            self.lin_q = nn.Linear(in_dim, h_dim, bias=bias)
            self.lin_k = nn.Linear(in_dim, h_dim, bias=bias)
            self.temperature = (h_dim // num_heads) ** 0.5
        elif agg_type == "GAT":
            self.lin_a = nn.Linear(in_dim, h_dim, bias=bias)
            self.alpha1 = nn.Parameter(torch.empty(
                1, num_heads, h_dim // num_heads))
            self.alpha2 = nn.Parameter(torch.empty(
                1, num_heads, h_dim // num_heads))
            self.leakyrelu = nn.LeakyReLU(negative_slope=0.2)
        else:  # "GCN"
            self.gcn_cache = None

        self.lin_v = nn.Linear(in_dim, h_dim, bias=bias)

    def reset_parameters(self):
        """Re-initialise the projections (and GAT attention vectors).

        The GCN cache, if any, is left untouched.
        """
        self.lin_v.reset_parameters()

        if self.agg_type == "Trans":
            self.lin_q.reset_parameters()
            self.lin_k.reset_parameters()
        elif self.agg_type == "GAT":
            self.lin_a.reset_parameters()
            glorot(self.alpha1)
            glorot(self.alpha2)

    def dense_forward(self, q, k, v, edge_mask, out, need_weights=False):
        """Dense attention between a block of query nodes and key nodes.

        Args:
            q, k: projections from ``forward`` (``[N, H, d]`` for ``"Trans"``,
                ``[N, H]`` for ``"GAT"``).
            v: value projections, ``[N, H, d]``.
            edge_mask: ``[G1, G2, M]`` or ``[G1, M]``. ``G1`` indexes the query
                nodes and ``G2`` the key nodes (``G1`` is used when ``G2`` is
                absent). ``M`` is None or a mask broadcastable to
                ``[H, N1, N2]``; its zero/False entries are blocked.
            out: ``[N, H*d]`` buffer. Rows ``G1`` are overwritten, not accumulated.
            need_weights: if true, also return the ``[H, N1, N2]`` attention.

        Returns:
            ``(out, attn)``, where ``attn`` is None unless ``need_weights``.

        Raises:
            NotImplementedError: for ``"GCN"``, which has no dense scoring.
        """
        if len(edge_mask) == 3:
            G1, G2 = edge_mask[0], edge_mask[1]
        else:
            G1 = G2 = edge_mask[0]
        M = edge_mask[-1]

        if self.agg_type == "Trans":
            q_blk = q[G1].transpose(0, 1)  # [H, N1, d]
            k_blk = k[G2].transpose(0, 1)  # [H, N2, d]
            attn = q_blk @ k_blk.transpose(-2, -1)  # [H, N1, N2]
        elif self.agg_type == "GAT":
            q_blk = q[G1].transpose(0, 1).unsqueeze(-1)  # [H, N1, 1]
            k_blk = k[G2].transpose(0, 1).unsqueeze(1)  # [H, 1, N2]
            attn = self.leakyrelu(q_blk + k_blk)  # [H, N1, N2]
        else:
            raise NotImplementedError(
                "dense attention masks are not supported for agg_type='GCN'")

        if M is not None:
            # Fill with the most negative finite value of attn's dtype. A literal
            # -1e9 is not representable in fp16 and makes masked_fill raise.
            attn = attn.masked_fill(torch.logical_not(M), torch.finfo(attn.dtype).min)

        attn = self.attn_dropout(F.softmax(attn, dim=-1))
        v_blk = v[G2].transpose(0, 1)  # [H, N2, d]

        out[G1] = (
            attn @ v_blk).transpose(0, 1).contiguous().view(G1.size(0), -1)  # [N1, Hd]

        return out, (attn if need_weights else None)

    def sparse_forward(self, q, k, v, edge_index, out=None, need_weights=False):
        """Edge-list attention: every edge sends ``v[edge_index[1]]`` to ``edge_index[0]``.

        Args:
            q, k, v: projections from ``forward``.
            edge_index: ``edge_index[0]`` holds the receiving (query) node of
                each edge and ``edge_index[1]`` the sending (key/value) node.
                For ``"GCN"`` it is replaced by the index ``calculate_norm_A`` returns.
            out: ``[N, H*d]`` buffer that messages are summed into. A zero
                buffer matching ``v`` is allocated when omitted.
            need_weights: if true, also return the per-edge coefficients.

        Returns:
            ``(out, attn)``, where ``out`` is ``[N, H*d]`` and ``attn`` is None
            unless ``need_weights``.
        """
        if out is None:
            out = v.new_zeros(v.size(0), self.h_dim)

        if self.agg_type == "GCN":
            if self.use_cache:
                # NOTE(review): once filled, the cache ignores later edge_index
                # arguments, so this assumes the same graph on every call.
                if self.gcn_cache is None:
                    self.gcn_cache = calculate_norm_A(edge_index)
                edge_index, attn = self.gcn_cache
            else:
                edge_index, attn = calculate_norm_A(edge_index)
        else:
            recv, send = edge_index[0], edge_index[1]
            if self.agg_type == "Trans":
                score = (q[recv] * k[send]).sum(-1)  # [E, H]
            else:  # "GAT"
                score = self.leakyrelu(q[recv] + k[send])  # [E, H]
            # Numerically stable softmax over each receiving node's incoming edges.
            score_max = scatter(score.detach(), index=recv, dim=0, reduce="max")
            score = torch.exp((score - score_max[recv]).unsqueeze(-1))  # [E, H, 1]
            score_sum = scatter(score, index=recv, dim=0, reduce="sum")
            attn = self.attn_dropout(score / score_sum[recv])  # [E, H, 1]

        msg = attn * v[edge_index[1]]  # [E, H, d]

        out = out.view(-1, self.n_head, self.h_dim // self.n_head)
        out = scatter(msg, index=edge_index[0], dim=0, out=out,
                      reduce="sum").view(-1, self.h_dim)  # [N, Hd]

        return out, (attn if need_weights else None)

    def forward(self, x, edge_mask, need_weights=False):
        """Project ``x`` and aggregate it under every mask in ``edge_mask``.

        Args:
            x: node features, ``[N, in_dim]``.
            edge_mask: list of masks applied in order to a shared output
                buffer. Tensor entries go to ``sparse_forward`` (summed into
                ``out``); every other entry goes to ``dense_forward`` (rows
                overwritten).
            need_weights: request attention coefficients.

        Returns:
            ``(out, attn)``. ``out`` is ``[N, h_dim]``. ``attn`` comes from the
            last mask processed, and is None when ``need_weights`` is false or
            ``edge_mask`` is empty.
        """
        N = x.size(0)
        head_dim = self.h_dim // self.n_head

        if self.agg_type == "Trans":
            q = self.lin_q(x).view(N, self.n_head, head_dim) / \
                self.temperature  # [N, H, d]
            k = self.lin_k(x).view(N, self.n_head, head_dim)  # [N, H, d]
        elif self.agg_type == "GAT":
            h = self.lin_a(x).view(N, self.n_head, head_dim)  # [N, H, d]
            q = (h * self.alpha1).sum(-1)  # [N, H]
            k = (h * self.alpha2).sum(-1)  # [N, H]
        else:  # "GCN": edge weights come from calculate_norm_A instead
            q = k = None
        v = self.lin_v(x).view(N, self.n_head, head_dim)  # [N, H, d]

        # Allocate with v's dtype/device so that the indexed/scatter writes match
        # (e.g. float64 models, or fp16 projections under autocast).
        out = v.new_zeros(N, self.h_dim)
        attn = None
        for em in edge_mask:
            # isinstance also accepts Tensor subclasses used as edge indices.
            step = self.sparse_forward if isinstance(em, torch.Tensor) else self.dense_forward
            out, attn = step(q, k, v, em, out, need_weights)
        return out, attn


class M3DLayerL(torch.nn.Module):
    """One M3D layer: [pre-norm] -> attention MoE -> [+res] -> [post-norm] -> ReLU -> dropout.

    The attention MoE has one shared expert of type ``local_type`` and two
    ``"Trans"`` router experts.

    Args:
        in_dim: input feature width.
        h_dim: output feature width.
        n_head: attention heads per expert.
        dropout: dropout after the activation.
        attn_dropout: dropout on attention coefficients inside the experts.
        bias: bias flag for the expert projections and the residual projection.
        use_cache: forwarded to the shared expert.
        use_res: add a linear projection of the layer input to the MoE output.
        norm_type: ``"ln"`` (LayerNorm), ``"rms"`` (RMSNorm with offset) or
            ``"bn"`` (BatchNorm1d). Any other value disables normalisation.
        norm_pos: ``"pre"`` normalises the layer input (``in_dim`` features) and
            ``"post"`` normalises the output (``h_dim`` features). Any other
            value disables normalisation.
        local_type: ``agg_type`` of the shared expert.
    """

    def __init__(self, in_dim, h_dim, n_head, dropout=0.1, attn_dropout=0.1, bias=True, use_cache=False, use_res=False, norm_type="ln", norm_pos="pre", local_type="GAT"):
        super(M3DLayerL, self).__init__()
        self.n_head = n_head
        self.h_dim = h_dim

        self.norm_pos = norm_pos

        self.use_res = use_res

        # Pre-norm sees the layer input (in_dim wide); post-norm sees the output
        # (h_dim wide). Always define norm1 so reset_parameters/forward can test it.
        if norm_pos == "pre":
            self.norm1 = self._build_norm(norm_type, in_dim)
        elif norm_pos == "post":
            self.norm1 = self._build_norm(norm_type, h_dim)
        else:
            self.norm1 = None

        self.attn_moe = BiAttentionMoeLayer(
            share_experts=[DualAttention(
                in_dim=in_dim, h_dim=h_dim, num_heads=n_head, dropout=attn_dropout, bias=bias, use_cache=use_cache, agg_type=local_type)],
            router_experts=[DualAttention(in_dim=in_dim, h_dim=h_dim, num_heads=n_head, dropout=attn_dropout, bias=bias, agg_type="Trans"),
                            DualAttention(in_dim=in_dim, h_dim=h_dim, num_heads=n_head, dropout=attn_dropout, bias=bias, agg_type="Trans")],
            gate_weights=nn.Parameter(torch.zeros([1, in_dim])),
            alpha_weights=nn.Parameter(torch.zeros([1, in_dim])),
        )

        self.dropout1 = nn.Dropout(p=dropout)

        if self.use_res:
            self.res = nn.Linear(in_dim, h_dim, bias=bias)

        self.act = nn.ReLU()

    @staticmethod
    def _build_norm(norm_type, dim):
        """Return the normalisation module for ``norm_type`` over ``dim`` features, or None."""
        if norm_type == "ln":
            return nn.LayerNorm(dim, eps=1e-6)
        if norm_type == "rms":
            return RMSNorm(dim, eps=1e-6, bias=True)
        if norm_type == "bn":
            return nn.BatchNorm1d(dim)
        return None

    def reset_parameters(self):
        """Re-initialise the norm (if any), the attention MoE and the residual projection."""
        if self.norm1 is not None:
            self.norm1.reset_parameters()
        self.attn_moe.reset_parameters()
        if self.use_res:
            self.res.reset_parameters()

    def forward(self, x, edge_masks, need_weights=False):
        """Apply the layer to node features ``x``.

        Args:
            x: node features (``in_dim`` wide).
            edge_masks: passed unchanged to the attention MoE.
            need_weights: passed unchanged to the attention MoE.

        Returns:
            Updated node features after ReLU and dropout.
        """
        residual = x
        if self.norm_pos == "pre" and self.norm1 is not None:
            x = self.norm1(x)
        x = self.attn_moe(x, edge_masks=edge_masks, need_weights=need_weights)

        if self.use_res:
            # The residual uses the un-normalised input.
            x = x + self.res(residual)
        if self.norm_pos == "post" and self.norm1 is not None:
            x = self.norm1(x)
        x = self.dropout1(self.act(x))

        return x


class M3DphormerL(torch.nn.Module):
    """Stack of ``M3DLayerL`` layers with an input projection and a linear classifier.

    Node rows are laid out as ``n_ori_nodes`` original nodes, then ``n_cluster``
    cluster nodes, then ``n_global`` global nodes (see ``cls_loss``). The
    classifier is applied to the sum of the input embedding and every layer's
    output.

    Args:
        n_ori_nodes, n_cluster, n_global: row counts for the three node groups.
        x_dim: raw feature width.
        h_dim: hidden width.
        n_cls: number of classes.
        n_head: attention heads per layer.
        layers: number of ``M3DLayerL`` layers.
        dropout, attn_dropout: dropout rates (input/layer and attention).
        local_type, use_cache, use_res, norm_type, norm_pos: forwarded to each layer.
        learn_global: if true, the last ``n_global`` rows are replaced by a
            learned embedding table.
    """

    def __init__(self, n_ori_nodes, n_cluster, n_global, x_dim, h_dim, n_cls, n_head, layers, dropout, attn_dropout, local_type="GAT", learn_global=False, use_cache=False, use_res=True, norm_type="ln", norm_pos="pre"):
        super(M3DphormerL, self).__init__()

        self.n_ori_nodes = n_ori_nodes
        self.n_cluster = n_cluster
        self.n_global = n_global
        self.learn_global = learn_global

        self.lin = nn.Linear(x_dim, h_dim)

        self.layers = nn.ModuleList()
        for _ in range(layers):
            self.layers.append(M3DLayerL(
                h_dim, h_dim, n_head, dropout, attn_dropout, local_type=local_type, use_cache=use_cache, use_res=use_res, norm_type=norm_type, norm_pos=norm_pos))

        self.cls = nn.Linear(h_dim, n_cls)
        if self.learn_global:
            self.global_embed = nn.Embedding(n_global, h_dim)
        self.in_dropout = nn.Dropout(p=dropout)
        self.in_act = nn.ReLU()

    def reset_parameters(self):
        """Re-initialise every sub-module."""
        self.lin.reset_parameters()
        if self.learn_global:
            self.global_embed.reset_parameters()
        for layer in self.layers:
            layer.reset_parameters()
        self.cls.reset_parameters()

    def forward(self, x, edge_masks):
        """Return class logits for every node row.

        Args:
            x: raw node features (``x_dim`` wide), one row per node.
            edge_masks: passed to every layer.

        Returns:
            Logits of shape ``[x.size(0), n_cls]``.
        """
        x = self.lin(x)
        x = self.in_dropout(x + self.in_act(x))

        if self.learn_global:
            # Overwrite the trailing n_global rows with the learned embeddings.
            # Use an explicit start index: x[-0:] would select every row.
            global_start = x.size(0) - self.n_global
            x[global_start:] = self.global_embed.weight

        # Jumping-knowledge style sum of the input embedding and each layer output.
        out = x
        for layer in self.layers:
            x = layer(x, edge_masks)
            out = out + x

        return self.cls(out)

    def cls_loss(self, out, y, y_global, idx, criterion):
        """Compute the local (labelled-node) loss and the weighted global-node loss.

        Args:
            out: logits for all rows (original, cluster, then global nodes).
            y: labels for all rows, squeezed on the last dim after indexing.
            y_global: labels for the ``n_global`` global rows.
            idx: selects the labelled nodes. NOTE(review): ``gamma`` divides by
                ``idx.size(0)``, which counts the selected nodes only if ``idx``
                is an index tensor rather than a boolean mask. Confirm with callers.
            criterion: loss function called as ``criterion(logits, labels)``.

        Returns:
            ``(loss_local, gamma * loss_global)`` with ``gamma = n_global / idx.size(0)``.
            When there are no global nodes the second term is a zero tensor,
            rather than ``0 * criterion(empty)`` (which is NaN for mean reduction).
        """
        loss_local = criterion(out[idx], y[idx].squeeze(-1))

        if self.n_global == 0:
            return loss_local, loss_local.new_zeros(())

        global_start = self.n_ori_nodes + self.n_cluster
        loss_global = criterion(out[global_start:], y_global.squeeze(-1))

        gamma = self.n_global / idx.size(0)

        return loss_local, gamma * loss_global