import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyexpat import features
from torch_geometric.data import Batch
from torch_geometric.utils import (to_dense_batch, 
                                   add_remaining_self_loops, 
                                   scatter)
from torch_geometric.utils.num_nodes import maybe_num_nodes

from model.layers import PathIntegral
from model.layers import NodePseudoSubsystem
from visio import data_show
class N2(nn.Module):
    """Base network holding the shared encoders, pseudo-node state and helpers.

    The constructor builds two input projections (``node_state_interface`` and
    ``feat_ff``), a learnable pseudo-node state and a ``NodePseudoSubsystem``
    updater.  The remaining methods are tensor helpers (degree counting,
    adjacency normalisation, input splitting, output post-processing).
    ``forward`` is a placeholder that subclasses override.
    """

    def __init__(
        self,
        *,
        d_in=1,
        d_ein=0,
        nclass=1,
        d_model=64,
        q_dim=64,
        n_q=8,
        n_c=8,
        n_pnode=5,
        T=1,
        n_obs_in=6,
        n_features=64,
        task_type="single-class",
        self_loop=True,
        pre_encoder=None,
        pos_encoder=None,
        dropout=0.1,
    ):
        """Build the shared sub-modules (all arguments are keyword-only).

        Args:
            d_in: Node input width; ``d_ein`` is added to it before use.
            d_ein: Edge input width.
            nclass: Number of classes (size of ``class_neuron`` dim 1).
            d_model: Output width of ``feat_ff``.
            q_dim: Width of the node / pseudo-node state.
            n_q: Forwarded to ``NodePseudoSubsystem`` and ``PathIntegral``.
            n_c: Leading dimension of ``class_neuron``.
            n_pnode: Number of pseudo nodes.
            T: Number of updater iterations (read by subclasses' ``forward``).
            n_obs_in: Number of leading per-node columns that
                ``_feature_prep`` treats as node observations.
            n_features: Accepted for backward compatibility; unused here.
            task_type: Selects the head built here and the branch taken in
                ``get_output``.
            self_loop: Stored on the instance only; ``_feature_prep`` always
                writes self-loops.
            pre_encoder: Stored on the instance; not used in this class.
            pos_encoder: Stored on the instance; not used in this class.
            dropout: Dropout probability of both input projections.
        """
        super().__init__()

        # Node and edge inputs share one combined input width.
        d_in = d_in + d_ein

        self.node_state_interface = nn.Sequential(
            nn.Linear(d_in, q_dim),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )

        self.feat_ff = nn.Sequential(
            nn.Linear(d_in, d_model),
            nn.LeakyReLU(),
            nn.Dropout(dropout),
        )

        self.pnode_state = nn.Parameter(torch.randn(1, n_pnode, q_dim))

        if task_type != "reg":
            self.node_state_updater = NodePseudoSubsystem(
                d_in, d_ein, q_dim, n_pnode, d_model, q_dim, n_q, dropout)
            self.class_neuron = nn.Parameter(torch.randn(n_c, nclass, q_dim))
            self.out_ff = PathIntegral(q_dim, n_q)
        else:
            self.node_state_updater = NodePseudoSubsystem(
                d_in, d_ein, q_dim, n_pnode, d_model, q_dim, n_q, dropout, False)

        self.pre_encoder = pre_encoder
        self.pos_encoder = pos_encoder
        self.task_type = task_type
        self.T = T
        self.n_q = n_q
        self.q_dim = q_dim
        self.n_pnode = n_pnode
        self.d_model = d_model
        self.self_loop = self_loop
        self.n_obs_in = n_obs_in

    def get_degree(self, adj):
        """Count the non-zero entries of ``adj`` along dim 1.

        Returns a float tensor with a trailing singleton dimension.  Zero
        counts are kept as zero (compare ``get_normalisation``).
        """
        return torch.sum(adj != 0, dim=1).unsqueeze(-1).float()

    def _get_sparse_normalized_adj(self, adj):
        """Return the outer product of the inverse square-root degrees.

        Entry ``[b, i, j]`` is ``deg_i ** -0.5 * deg_j ** -0.5``, with the
        factor of a zero-degree node replaced by 0.  The adjacency pattern
        itself is not applied to the result.
        """
        deg_inv_sqrt = self.get_degree(adj).pow(-0.5)
        # Isolated nodes give 0 ** -0.5 == inf; zero them so they do not
        # poison the product below.
        deg_inv_sqrt = deg_inv_sqrt.masked_fill(deg_inv_sqrt == float("inf"), 0)
        return torch.bmm(deg_inv_sqrt, deg_inv_sqrt.transpose(1, 2))

    @torch.no_grad()
    def get_normalisation(self, adj):
        """Count non-zero entries of ``adj`` along dim 1, with a floor of 1.

        The floor keeps the result safe to divide by.  Returns a float tensor
        with a trailing singleton dimension.
        """
        counts = torch.sum(adj != 0, dim=1).unsqueeze(-1)
        counts[counts == 0] = 1
        return counts.float()

    def calculate_edge_weight_batch(self, adj_matrix_batch, degrees_batch):
        """Compute per-edge weights for a batch of graphs.

        For every entry where the adjacency equals 1 the weight is the product
        of the two endpoint degrees; every other entry gets weight 1.

        :param adj_matrix_batch: adjacency matrices, shape
            (batch_size, num_vertices, num_vertices)
        :param degrees_batch: vertex degrees holding ``num_vertices`` values
            per graph, e.g. shape (batch_size, num_vertices) or
            (batch_size, num_vertices, 1)
        :return: edge weights, shape (batch_size, num_vertices, num_vertices)
        """
        batch_size, num_vertices, _ = adj_matrix_batch.shape
        # reshape (rather than squeeze) keeps single-vertex graphs 1-D per item.
        degrees = degrees_batch.reshape(batch_size, num_vertices)
        degree_products = degrees.unsqueeze(2) * degrees.unsqueeze(1)
        return torch.where(adj_matrix_batch == 1,
                           degree_products,
                           torch.ones_like(degree_products))

    def symmetric_normalization(self, adj_matrix_batch):
        """Symmetrically normalise adjacency matrices: D^-1/2 · A · D^-1/2.

        :param adj_matrix_batch: adjacency matrices, shape (batch_size, n, n)
        :return: normalised matrices of the same shape
        """
        # Row sums are the degrees; the small offset avoids 0 ** -0.5.
        inv_sqrt_degree = torch.pow(adj_matrix_batch.sum(dim=-1) + 1e-7, -0.5)
        # Scaling rows and columns by broadcasting equals multiplying by the
        # diagonal matrices on both sides, without materialising them.
        return (inv_sqrt_degree.unsqueeze(-1)
                * adj_matrix_batch
                * inv_sqrt_degree.unsqueeze(-2))

    def batch_compute_edge_weights(self, adj_matrix: torch.Tensor) -> torch.Tensor:
        """Return ``deg_i ** -0.5 * deg_j ** -0.5`` for every node pair.

        Degrees are the sums over the last dimension of ``adj_matrix``
        (batch, n, n); non-finite inverse roots (e.g. from zero degree) are
        replaced by 0.
        """
        degree = adj_matrix.sum(dim=-1)  # (batch, n)
        inv_sqrt_degree = 1.0 / torch.sqrt(degree)  # inf where degree == 0
        inv_sqrt_degree = torch.nan_to_num(inv_sqrt_degree, posinf=0.0, neginf=0.0)
        # Outer product over the node axis.
        return inv_sqrt_degree.unsqueeze(2) * inv_sqrt_degree.unsqueeze(1)

    def _feature_prep(self, state):
        """Split a raw state tensor into node features, edge weights and norm.

        ``state`` is 2-D or 3-D; a 2-D input gets a leading batch axis.  After
        swapping the last two axes, the first ``n_obs_in`` columns are the
        node observations and the remaining columns the adjacency.  The input
        tensor is neither reshaped nor written to.

        Returns:
            node_state: view of the node observation columns.
            new_matrix: tensor shaped like the adjacency, 0.01 where the
                adjacency (with self-loops) is 0 and 0.99 elsewhere.
            norm: ``get_normalisation`` of the adjacency with self-loops.
        """
        if state.dim() == 2:
            state = state.unsqueeze(0)
        # Out-of-place transpose: the caller's tensor keeps its layout.
        state = state.transpose(-1, -2)

        # Fallback covers instances created before the attribute existed.
        n_obs = getattr(self, "n_obs_in", 6)
        node_state = state[:, :, :n_obs]
        # Clone before writing self-loops so the caller's storage is untouched.
        adj = state[:, :, n_obs:].clone()
        torch.diagonal(adj, dim1=-2, dim2=-1).fill_(1)

        norm = self.get_normalisation(adj)

        # Built directly on adj's device in the default float dtype.
        new_matrix = torch.full(adj.shape, 0.99,
                                dtype=torch.get_default_dtype(),
                                device=adj.device)
        new_matrix = new_matrix.masked_fill(adj == 0, 0.01)
        return node_state, new_matrix, norm

    def get_output(self, features):
        """Post-process raw model outputs according to ``self.task_type``.

        * contains "single-class": log-softmax over the last dimension.
        * "multi-class" / "reg": returned unchanged.
        * "binary-class": flattened.
        * contains "link": the last dimension is split in half into source and
          target embeddings; pairwise scores are the scaled dot product
          ("scale-dot") or the cosine similarity ("cosine"), times 2.

        Raises:
            ValueError: for an unknown task type, including a "link" task
                that names neither "scale-dot" nor "cosine".
        """
        task = self.task_type
        if "single-class" in task:
            return F.log_softmax(features, dim=-1)
        if task in ("multi-class", "reg"):
            return features
        if task == "binary-class":
            return features.flatten()
        if "link" in task:
            half = features.shape[-1] // 2
            node_in = features[:, :half]
            node_out = features[:, half:]
            scores = torch.matmul(node_in, node_out.T)
            if "scale-dot" in task:
                scores = scores / half
            elif "cosine" in task:
                # Entry (i, j) must be divided by |in_i| * |out_j|, i.e. the
                # outer product of the norms; the clamp avoids 0 / 0.
                denom = (torch.norm(node_in, dim=-1).unsqueeze(-1)
                         * torch.norm(node_out, dim=-1).unsqueeze(0))
                scores = scores / denom.clamp_min(1e-12)
            else:
                raise ValueError("Unsupported task type " + task)
            return scores * 2
        raise ValueError("Unsupported task type " + task)

    def forward(self):
        """Placeholder; subclasses provide the actual forward pass."""
        pass


class N2Node(N2):
    """Node-level model: N2 encoders + iterative state updater + readout.

    ``forward`` splits the raw input with ``_feature_prep``, projects the node
    observations into state and feature space, runs ``node_state_updater``
    ``T`` times and maps the resulting features through ``readout_layer``.
    """

    # When True (the default) ``forward`` calls ``data_show`` on the initial
    # states.  Set to False on the class or an instance to skip that call.
    show_states = True

    def __init__(
        self,
        *,
        d_in=1,
        d_ein=0,
        nclass=1,
        d_model=64,
        q_dim=64,
        n_q=8,
        n_c=8,
        n_pnode=256,
        T=1,
        task_type="single-class",
        pre_encoder=None,
        pos_encoder=None,
        self_loop=True,
        dropout=0.1,
    ):
        """Build the base model plus the node-level layers.

        All arguments are keyword-only and forwarded unchanged to
        ``N2.__init__``; ``d_in`` and ``q_dim`` additionally size
        ``edge_embedding_layer``.
        """
        super().__init__(
            d_in=d_in,
            d_ein=d_ein,
            nclass=nclass,
            d_model=d_model,
            q_dim=q_dim,
            n_q=n_q,
            n_c=n_c,
            n_pnode=n_pnode,
            T=T,
            task_type=task_type,
            pre_encoder=pre_encoder,
            pos_encoder=pos_encoder,
            self_loop=self_loop,
            dropout=dropout,
        )
        # Kept for callers that read it.  It is deliberately NOT used to place
        # parameters: those must follow ``Module.to()`` like all the others.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.edge_embedding_layer = EdgeAndNodeEmbeddingLayer(d_in, q_dim)
        # NOTE(review): both layers below are sized with a literal 64 rather
        # than d_model / q_dim — confirm this is intended for non-default sizes.
        self.UpdateNodeEmbeddingLayer = UpdateNodeEmbeddingLayer(n_features=64)
        self.readout_layer = ReadoutLayer(n_features=64)

        # Reset the slot created by the base class, then re-create the
        # pseudo-node parameter (shape is independent of the batch size).
        self.register_parameter('pnode_state', None)
        self._init_pnode_params(n_pnode, q_dim)

    def _init_pnode_params(self, n_pnode, q_dim):
        """Create the (1, n_pnode, q_dim) pseudo-node parameter if unset.

        The parameter is created on the default device so that it moves with
        the rest of the module on ``.to(...)`` / ``.cuda()``.
        """
        if self.pnode_state is None:
            self.pnode_state = nn.Parameter(torch.randn(1, n_pnode, q_dim))

    def _get_pnode_state(self, batch_size):
        """Broadcast the pseudo-node parameter to ``batch_size`` graphs.

        Uses ``expand``, so no data is copied.

        Raises:
            RuntimeError: if the parameter is not 3-dimensional.
        """
        if self.pnode_state.dim() != 3:
            raise RuntimeError("Invalid pnode_state dimension")
        return self.pnode_state.expand(batch_size, -1, -1)

    def forward(self, data):
        """Run the model on a raw state tensor.

        Args:
            data: tensor accepted by ``_feature_prep`` (2-D or 3-D; per-node
                observation columns followed by adjacency columns once the
                last two axes are swapped).

        Returns:
            The readout output with all size-1 dimensions squeezed away.
        """
        features, edge_weight, norm = self._feature_prep(data)

        # Project the observations into the node state space: (b_s, n, q_dim).
        node_state = self.node_state_interface(features)

        # Project the observations into the (wider) feature space.
        features = self.feat_ff(features)

        pnode_state = self._get_pnode_state(features.shape[0])

        if self.show_states:
            data_show(node_state, pnode_state, title="initial")

        for _ in range(self.T):
            node_state, pnode_state, features = self.node_state_updater(
                edge_weight=edge_weight,
                features=features,
                node_state=node_state,
                pnode_state=pnode_state,
                norm=norm,
            )

        out = self.readout_layer(features)
        # NOTE(review): squeeze() without a dim also drops a batch axis of
        # size 1; left unchanged because callers may rely on that shape.
        return out.squeeze()

class ReadoutLayer(nn.Module):
    """Per-node readout that mixes each node's features with a graph summary.

    The input embeddings (width ``2 * n_features``) are first projected to
    ``n_features``.  Their mean over the node axis is passed through
    ``layer_pooled`` and concatenated to every node's local features; an MLP
    then maps the result to one value per node.
    """

    def __init__(self, n_features, n_hid=None, bias_pool=False, bias_readout=True):
        """
        Args:
            n_features: Width of the projected node embeddings.
            n_hid: Hidden layer widths of the readout MLP: ``None`` (no hidden
                layer), a single int, or a list/tuple of ints.
            bias_pool: Whether ``layer_pooled`` has a bias.
            bias_readout: Whether ``updatefeatures`` and the MLP layers have a
                bias.
        """
        super().__init__()
        self.updatefeatures = nn.Linear(2 * n_features, n_features, bias=bias_readout)
        self.layer_pooled = nn.Linear(int(n_features), int(n_features), bias=bias_pool)

        if n_hid is None:
            hidden = []
        elif isinstance(n_hid, (list, tuple)):
            hidden = list(n_hid)
        else:
            hidden = [n_hid]
        # Pooled + local features in, a single scalar per node out.
        sizes = [2 * n_features] + hidden + [1]
        self.layers_readout = nn.ModuleList(
            [nn.Linear(n_in, n_out, bias=bias_readout)
             for n_in, n_out in zip(sizes, sizes[1:])]
        )

    def forward(self, node_embeddings):
        """Map (batch, n, 2 * n_features) embeddings to (batch, n, 1) outputs."""
        f_local = self.updatefeatures(node_embeddings)

        # Mean over the node axis, then one learned projection per graph.
        h_pooled = self.layer_pooled(f_local.sum(dim=1) / f_local.shape[1])

        # Give every node a (broadcast) copy of its graph's pooled vector.
        f_pooled = h_pooled.unsqueeze(1).expand_as(f_local)

        features = F.relu(torch.cat([f_pooled, f_local], dim=-1))

        last = len(self.layers_readout) - 1
        for i, layer in enumerate(self.layers_readout):
            features = layer(features)
            # No activation after the final layer.
            if i < last:
                features = F.relu(features)
        return features

class UpdateNodeEmbeddingLayer(nn.Module):
    """One message-passing step over dense adjacency matrices.

    ``message_layer`` consumes vectors of width ``2 * n_features`` and
    ``update_layer`` vectors of width ``3 * n_features``, so the incoming
    embeddings must have width ``2 * n_features``; the output has width
    ``n_features``.
    """

    def __init__(self, n_features):
        super().__init__()

        self.message_layer = nn.Linear(2 * n_features, n_features, bias=False)
        self.update_layer = nn.Linear(3 * n_features, n_features, bias=False)

    def forward(self, current_node_embeddings, adj, norm):
        """Aggregate neighbours, build a message and update each node.

        Args:
            current_node_embeddings: embeddings of width ``2 * n_features``.
            adj: adjacency used to sum neighbour embeddings via ``matmul``.
            norm: divisor applied to the aggregated embeddings.

        Returns:
            Updated embeddings of width ``n_features``.
        """
        # Weighted neighbour sum, normalised per node.
        neighbourhood = torch.matmul(adj, current_node_embeddings) / norm

        message = F.relu(self.message_layer(neighbourhood))

        # Each node sees its own embedding next to the incoming message.
        combined = torch.cat([current_node_embeddings, message], dim=-1)
        return F.relu(self.update_layer(combined))

class EdgeAndNodeEmbeddingLayer(nn.Module):
    """Embed each node's incident edges from its neighbours' raw features.

    For every node the neighbours' feature vectors (zeroed where there is no
    edge) are embedded, summed over the neighbours and divided by ``norm``;
    the relative ``norm`` is appended as an extra channel before a final
    linear layer.
    """

    def __init__(self, n_obs_in, n_features):
        super().__init__()
        self.n_obs_in = n_obs_in
        self.n_features = n_features

        # One output channel is left free for the relative-norm feature that
        # is concatenated in forward().
        self.edge_embedding_NN = nn.Linear(int(n_obs_in), n_features - 1, bias=False)
        self.edge_feature_NN = nn.Linear(n_features, n_features, bias=False)

    def forward(self, node_features, adj, norm):
        """Return edge embeddings of shape (batch, n, n_features).

        Args:
            node_features: raw node features, (batch, n, n_obs_in).
            adj: square adjacency, (batch, n, n); non-zero entries mark edges.
            norm: per-node divisor with a trailing singleton axis,
                (batch, n, 1).
        """
        n_rows = adj.shape[-2]

        # (batch, n, f) -> (batch, 1, n, f) -> (batch, n, n, f): entry
        # [b, i, j] holds the features of node j.
        per_edge = node_features.unsqueeze(-3).repeat(1, n_rows, 1, 1)

        # Zero the features of node pairs that are not connected.
        has_edge = (adj.unsqueeze(-1) != 0).float()
        per_edge.mul_(has_edge)

        # Flatten the pair axes so the linear layer sees one row per pair.
        batch, rows, width = per_edge.shape[0], per_edge.shape[1], per_edge.shape[-1]
        flat = torch.reshape(per_edge, (batch, rows * rows, width))
        embedded = F.relu(self.edge_embedding_NN(flat))

        # Restore the (batch, n, n, n_features - 1) layout.
        embedded = torch.reshape(
            embedded,
            (adj.shape[0], adj.shape[1], adj.shape[1], self.n_features - 1),
        )

        # Normalised sum over each node's neighbours.
        pooled = embedded.sum(dim=2) / norm

        relative_norm = norm / norm.max()
        stacked = torch.cat([pooled, relative_norm], dim=-1)
        return F.relu(self.edge_feature_NN(stacked))

