import torch
import torch.nn as nn
from torch_geometric.utils import to_torch_coo_tensor
import torch.nn.functional as F
from visio import data_show


# Produces a matrix that reflects how strongly information propagates between nodes (an association / weight matrix).
class PathIntegral(nn.Module):
    """Score the pairwise association between two sets of state vectors.

    The score is a dot product between "in" and "out" states along the last
    dimension, mixed over ``n_q`` state copies with the learnable weights
    ``lambda_copies``. Rows whose summed magnitude exceeds 1e4 are rescaled
    before the result is returned.

    Args:
        q_dim: Size used to normalise the dot products.
        n_q: Number of state copies; ``lambda_copies`` has shape (n_q, 1, 1).
    """

    def __init__(self, q_dim, n_q):
        super(PathIntegral, self).__init__()
        # One learnable mixing weight per state copy.
        self.lambda_copies = nn.Parameter(torch.randn(n_q, 1, 1))
        self.n_q = n_q
        self.q_dim = q_dim
        # Optional stored operands, used when _path_integral gets no arguments.
        self.in_subsystem = None
        self.out_subsystem = None

    def _path_integral(self, in_subsystem=None, out_subsystem=None):
        """Compute the weighted association matrix.

        Args:
            in_subsystem: "In" states; falls back to ``self.in_subsystem``.
            out_subsystem: "Out" states; falls back to ``self.out_subsystem``.

        Returns:
            Tensor whose last two dimensions index (in element, out element).

        Raises:
            ValueError: If either operand is missing after the fallback.
        """
        # Identity checks: `== None` on a tensor goes through Tensor.__eq__.
        if in_subsystem is None:
            in_subsystem = self.in_subsystem
        if out_subsystem is None:
            out_subsystem = self.out_subsystem
        if in_subsystem is None or out_subsystem is None:
            raise ValueError("Path integral requires computational object.")

        # The copy axis is dimension -3. It is compared against self.n_q for
        # every batch size (the batched path used to hard-code 8, which only
        # worked when n_q == 8). Conditions are kept inline so that
        # out_subsystem.shape[-3] is only read when the left side matches.
        # TODO: ablation study between attention and dist
        if in_subsystem.shape[-3] == 1 and out_subsystem.shape[-3] == self.n_q:
            # Single "in" copy vs. n_q "out" copies: weight and pool the out
            # copies first, then take one dot product.
            out_subsystem = out_subsystem * self.lambda_copies
            out_subsystem_sum = out_subsystem.sum(-3) / (self.n_q * self.q_dim)
            weighted_dist_sum = torch.matmul(in_subsystem.squeeze(-3),
                                             out_subsystem_sum.transpose(-2, -1))
        elif in_subsystem.shape[-3] == self.n_q and out_subsystem.shape[-3] == 1:
            # Mirror case: n_q "in" copies vs. a single "out" copy.
            in_subsystem = in_subsystem * self.lambda_copies
            in_subsystem_sum = in_subsystem.sum(-3) / (self.n_q * self.q_dim)
            weighted_dist_sum = torch.matmul(in_subsystem_sum,
                                             out_subsystem.squeeze(-3).transpose(-2, -1))
        elif in_subsystem.shape[-3] == self.n_q and out_subsystem.shape[-3] == self.n_q:
            # Both sides carry n_q copies: pool each side over its copies,
            # take the dot product, then apply the copy weights.
            # NOTE(review): unlike the other branches this one does not divide
            # by q_dim; kept as in the original.
            dist = torch.matmul(in_subsystem.sum(-3),
                                out_subsystem.sum(-3).transpose(-2, -1))
            dist = dist.unsqueeze(1)
            weighted_dist = dist * self.lambda_copies
            weighted_dist_sum = weighted_dist.sum(-3) / self.n_q  # (n_in, n_out)
        elif in_subsystem.shape[0] == 1:
            # Leading dimension of size 1: broadcasted matmul, then expand
            # over the copies via lambda_copies and average them.
            dist = torch.matmul(in_subsystem, out_subsystem.transpose(-2, -1)) / self.q_dim
            weighted_dist = dist * self.lambda_copies
            weighted_dist_sum = weighted_dist.sum(-3) / self.n_q  # (n_in, n_out)
        elif in_subsystem.dim() == 3 and out_subsystem.dim() == 4:
            # e.g. [64,5,64] x [64,1,20,64] -> [64,5,20] (shapes taken from
            # the original comment).
            dist = torch.matmul(in_subsystem, out_subsystem.squeeze(1).transpose(-2, -1)) / self.q_dim
            weighted_dist = dist.unsqueeze(1) * self.lambda_copies
            weighted_dist_sum = weighted_dist.sum(-3) / self.n_q  # (n_in, n_out)
        elif in_subsystem.dim() == 4 and out_subsystem.dim() == 3:
            dist = torch.matmul(in_subsystem.squeeze(1), out_subsystem.transpose(-2, -1)) / self.q_dim
            weighted_dist = dist.unsqueeze(1) * self.lambda_copies
            weighted_dist_sum = weighted_dist.sum(-3) / self.n_q  # (n_in, n_out)
        else:
            # Remaining case; torch.bmm requires both operands to be 3-D,
            # e.g. (64,5,64) x (64,5,64) per the original comment.
            dist = torch.bmm(in_subsystem, out_subsystem.transpose(-2, -1)) / self.q_dim
            dist = dist.unsqueeze(1)
            weighted_dist = dist * self.lambda_copies
            weighted_dist_sum = weighted_dist.sum(-3) / self.n_q  # (n_in, n_out)

        # Clip: rows whose absolute sum over the last dimension exceeds 1e4
        # are scaled down by 1e4 / |sum|. The scale factor is computed under
        # no_grad, so it acts as a constant in the backward pass.
        with torch.no_grad():
            clip_check = torch.abs(weighted_dist_sum.sum(-1, keepdim=True))
            fill_value = torch.ones_like(clip_check)
            scaler = torch.where(clip_check > 1e+4, 1e+4 / clip_check, fill_value)
        weighted_dist_sum = scaler * weighted_dist_sum
        return weighted_dist_sum

    def forward(self, in_subsystem, out_subsystem):
        """Return the association matrix for the two given state sets."""
        return self._path_integral(in_subsystem, out_subsystem)


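# Handles information exchange and feature transformation among pseudo-nodes, and between pseudo-nodes and the global information.
# Input: the pseudo-node states and the messages G sent by the graph nodes.
# Output: the pseudo-node displacement ΔR and the message Mp that the pseudo-nodes send back to the graph.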
class PNodeCommunicator(nn.Module):
    """Exchange information among pseudo-nodes.

    Mixes the incoming per-pseudo-node information with a pseudo-node to
    pseudo-node association matrix, then derives two outputs from the mixed
    result: a displacement for the pseudo-node states and a message value.

    Args:
        d_in: Width of the incoming information ``glob``.
        d_out: Width of the returned message value.
        q_dim: State size; also the first factor used to unflatten the
            displacement.
        n_q: Number of state copies.
        dropout: Dropout probability used in both output heads.
    """

    def __init__(self, d_in, d_out, q_dim, n_q, dropout):
        super().__init__()
        self.q_dim = q_dim
        self.pnode_agg = PathIntegral(q_dim, n_q)

        def make_head(width):
            # Linear -> LeakyReLU -> Dropout, shared layout for both heads.
            return nn.Sequential(nn.Linear(d_in, width),
                                 nn.LeakyReLU(),
                                 nn.Dropout(dropout))

        # Head that produces the (flattened) pseudo-node displacement.
        self.glob2disp = make_head(q_dim * n_q)
        # Head that produces the message value.
        self.glob2value = make_head(d_out)

    def forward(self, state, glob):
        """Return ``(displacement, dispatch_value)``.

        Args:
            state: Pseudo-node states, scored against themselves.
            glob: Information gathered per pseudo-node; last dim is ``d_in``.

        The final ``permute(0, 3, 1, 2)`` needs a 4-D tensor, so the mixed
        information must be 3-D.
        """
        # Pseudo-node to pseudo-node association (Epp in the original notes).
        affinity = self.pnode_agg(state, state)
        # Redistribute the gathered information among the pseudo-nodes.
        mixed = torch.matmul(affinity, glob)

        # Last dim of size q_dim * n_q is split as (q_dim, n_q); the trailing
        # axis is then moved to position 1.
        flat_shift = self.glob2disp(mixed)
        shift = flat_shift.unflatten(-1, (self.q_dim, -1)).permute(0, 3, 1, 2)

        # Message value (Mp in the original notes).
        message = self.glob2value(mixed)
        return shift, message


class NodePseudoSubsystem(nn.Module):
    """
    Neurons as Nodes: get nodes' neuronal state through neurons as nodes.

    Couples graph nodes with a set of pseudo-nodes. ``forward`` gathers node
    features into the pseudo-nodes, updates the pseudo-node states, sends a
    global message back to the nodes, and then updates node features and
    node states.

    Args:
        d_in: Not referenced by this class.
        d_ein: Not referenced by this class.
        d_out: Output width of the second pseudo-node communicator.
        n_pnode: Number of pseudo-nodes (stored as ``pnode_num``).
        d_model: Feature width.
        q_dim: State width.
        n_q: Number of state copies.
        dropout: Dropout probability for the feed-forward heads.
        norm: If True, create the LayerNorm modules and normalise
            pseudo-node states after each update.
    """
    def __init__(self, d_in, d_ein, d_out, n_pnode, d_model, q_dim, n_q, dropout=0.0, norm=True):
        super(NodePseudoSubsystem, self).__init__()
        self.collection1 = PathIntegral(q_dim, n_q)
        self.pnode_agg1 = PNodeCommunicator(d_model, d_model, q_dim, n_q, dropout)

        self.inspection = PathIntegral(q_dim, n_q)
        self.edge_wise_ff = nn.Linear(d_model, d_model)

        # Maps updated features (width d_model) to a node-state increment.
        self.hstate_interface = nn.Sequential(nn.Linear(d_model, q_dim),
                                              nn.LeakyReLU(),
                                              nn.Dropout(dropout))

        self.collection2 = PathIntegral(q_dim, n_q)
        self.pnode_agg2 = PNodeCommunicator(d_model * 3 + q_dim, d_out, q_dim, n_q, dropout)

        self.dispatch = PathIntegral(q_dim, n_q)
        self.feat_ff = nn.Sequential(nn.Linear(q_dim, d_model),
                                     nn.LeakyReLU(),
                                     nn.Dropout(dropout))
        if norm:
            self.phidden_norm = nn.LayerNorm(q_dim)
            self.hidden_norm = nn.LayerNorm(q_dim)
            self.pout_norm = nn.LayerNorm(q_dim)
            self.out_norm = nn.LayerNorm(q_dim)
            self.feat_norm = nn.LayerNorm(d_model)
        self.norm = norm
        print(f"Using norm: {norm}")

        self.time_embedding = nn.Embedding(6, d_model)
        self.q_dim = q_dim
        self.pnode_num = n_pnode
        self.d_model = d_model
        # NOTE(review): n_features is fixed at 64. forward() feeds
        # UpdateFeatures a concat of width 2 * d_model + q_dim and passes its
        # output to hstate_interface, which appears to assume d_model == 128
        # and q_dim == 64 -- confirm before using other sizes.
        self.UpdateFeatures = UpdateFeatures(n_features=64)
        self.UpdateNodeEmbeddingLayer = UpdateNodeEmbeddingLayer(n_features=64)

    def _feature_inspection(self, features, node_state, pnode_state, node_num):
        """Node -> pseudo-node aggregation followed by pseudo-node -> node dispatch.

        Stores the pseudo-node message in ``self.str_inspector`` and the
        updated pseudo-node state in ``self.pnode_state``.

        Returns:
            ``(inspector, pnode_state)``: the global message per node and the
            updated pseudo-node state.
        """
        # Association between pseudo-nodes and nodes (Enp).
        ipn2n_dist = self.collection1(pnode_state, node_state)  # (n_pnode, n)

        # Aggregate node features into each pseudo-node, averaged by node count.
        glob_init = torch.matmul(ipn2n_dist, features) / node_num  # (n_pnode, d_model)

        # Pseudo-node displacement and the message the pseudo-nodes hold (Mp).
        pnode_disp1, self.str_inspector = self.pnode_agg1(pnode_state, glob_init)

        # Update the pseudo-node state. A 3-D state has no copy axis yet, so
        # insert one at position 1 and let broadcasting expand it to the
        # displacement's copy count. This used to be special-cased for a
        # batch size of exactly 64 and exactly 8 copies.
        if pnode_state.dim() == 3 and pnode_disp1.dim() == 4:
            pnode_state = pnode_state.unsqueeze(1)
        pnode_state = pnode_disp1 + pnode_state

        if self.norm:
            pnode_state = self.phidden_norm(pnode_state)

        # Keep the current pseudo-node state on the module.
        self.pnode_state = pnode_state

        # Association in the reverse direction (nodes vs. pseudo-nodes).
        n2ipn_dist = self.inspection(node_state, pnode_state)  # (n, n_pnode)

        # Mglobal = association * Mp
        inspector = torch.matmul(n2ipn_dist, self.str_inspector)  # (n, d_model)
        return inspector, pnode_state

    def _pnode_aggregator(self, pnode_state, hnode_state, insp_out, node_num):
        """Second node -> pseudo-node -> node round.

        Reads ``self.str_inspector``, so ``_feature_inspection`` must have
        run first. Not called by ``forward`` in this file.

        Returns:
            ``(dispatch_value, pnode_state)``.
        """
        # Association between pseudo-node states and hidden node states.
        opn2n_dist = self.collection2(pnode_state, hnode_state)  # (n_pnode, n)

        # Gather node-level information into the pseudo-nodes.
        glob_info = torch.matmul(opn2n_dist, insp_out) / node_num  # (n_pnode, d_model * 2)

        # Append the pseudo-node message from the first round (Mp).
        glob_info = torch.concat((glob_info, self.str_inspector), -1)

        # Pseudo-node level refinement: displacement and new message.
        pnode_disp2, dispatch_value = self.pnode_agg2(pnode_state, glob_info)  # (n_pnode, n_pnode)
        # R(l) = R^(l) + dR(l)
        pnode_state = pnode_state + pnode_disp2

        if self.norm:
            pnode_state = self.pout_norm(pnode_state)

        # Association between hidden node states and updated pseudo-nodes.
        n2opn_dist = self.dispatch(hnode_state, pnode_state)  # (b_s, n, n_pnode)
        # Message delivered back to the nodes (Mglob(l)).
        dispatch_value = torch.matmul(n2opn_dist, dispatch_value)
        return dispatch_value, pnode_state

    def _edge_aggregation(self, insp_in, edge_weight):
        """Propagate node features over the graph: ``bmm(edge_weight, insp_in)``.

        Not called by ``forward`` in this file.
        """
        insp_out = torch.bmm(edge_weight, insp_in)  # (b_s * n, 2 * d_model)
        return insp_out

    def forward(self, *,
                         edge_weight=None,
                         edge_attr=None,
                         features=None,
                         node_state=None,
                         pnode_state=None,
                         mask=None,
                        norm=None,
                         size=None):
        """Run one node / pseudo-node interaction step.

        Args:
            edge_weight: Adjacency used by ``UpdateFeatures``.
            edge_attr: Unused.
            features: Node features; per the original comment
                (b_s, n, d_model).
            node_state: Node states; per the original comment
                (b_s, n, q_dim).
            pnode_state: Pseudo-node states.
            mask: Optional mask; when given, the node count is
                ``mask.sum(-2, keepdim=True)`` instead of ``n``.
            norm: Divisor applied to the aggregated features inside
                ``UpdateFeatures``; the default of None is not usable there.
            size: Unused.

        Returns:
            ``(node_state, pnode_state, features)`` after the update.
        """
        _, n = features.shape[:2]
        node_num = n if mask is None else mask.sum(-2, keepdim=True)

        # Global message passing: insp is the global message per node.
        insp, pnode_state = self._feature_inspection(features,
                                                     node_state.unsqueeze(1),
                                                     pnode_state,
                                                     node_num)

        # Concatenate features, global message and node state per node.
        insp_in = torch.concat((features, insp, node_state), -1)

        features = self.UpdateFeatures(insp_in, norm, edge_weight)

        # Residual update of the node state from the new features.
        hnode_state = self.hstate_interface(features)
        node_state = hnode_state + node_state

        return node_state, pnode_state, features


class UpdateFeatures(nn.Module):
    """One message-passing step over node embeddings of width 5 * n_features.

    The message layer maps 5 * n_features -> 2 * n_features and the update
    layer maps 7 * n_features -> 2 * n_features; neither has a bias.
    """

    def __init__(self, n_features):
        super().__init__()
        width = n_features
        self.message_layer = nn.Linear(5 * width, 2 * width, bias=False)
        self.update_layer = nn.Linear(7 * width, 2 * width, bias=False)

    def forward(self, current_node_embeddings, norm, adj):
        """Return the updated node embeddings.

        Args:
            current_node_embeddings: Node embeddings, last dim 5 * n_features.
            norm: Divisor applied to the aggregated embeddings.
            adj: Matrix left-multiplied onto the embeddings.
        """
        # Aggregate over adj and normalise.
        pooled = torch.matmul(adj, current_node_embeddings) / norm
        # Message computed from the aggregated embeddings.
        message = F.relu(self.message_layer(pooled))
        # Combine each node's own embedding with its message.
        stacked = torch.cat((current_node_embeddings, message), dim=-1)
        return F.relu(self.update_layer(stacked))

class UpdateNodeEmbeddingLayer(nn.Module):
    """One message-passing step that keeps the embedding width at n_features.

    The message layer maps n_features -> n_features and the update layer
    maps 2 * n_features -> n_features; neither has a bias.
    """

    def __init__(self, n_features):
        super().__init__()
        width = n_features
        self.message_layer = nn.Linear(width, width, bias=False)
        self.update_layer = nn.Linear(2 * width, width, bias=False)

    def forward(self, current_node_embeddings, norm, adj):
        """Return the updated node embeddings.

        Args:
            current_node_embeddings: Node embeddings, last dim n_features.
            norm: Divisor applied to the aggregated embeddings.
            adj: Matrix left-multiplied onto the embeddings.
        """
        # Aggregate over adj and normalise.
        pooled = torch.matmul(adj, current_node_embeddings) / norm
        # Message computed from the aggregated embeddings.
        message = F.relu(self.message_layer(pooled))
        # Combine each node's own embedding with its message.
        stacked = torch.cat((current_node_embeddings, message), dim=-1)
        return F.relu(self.update_layer(stacked))