import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_scatter


class GPRProp(nn.Module):
    """
    广义PageRank传播层 (GPR-GNN).

    学习最优的传播系数，实现可控的消息传递。

    Args:
        K (int): 传播次数，决定了GPR-GNN的感受野大小
        alpha (float, optional): 初始化传播系数的参数，影响初始化的倾向性
        cached (bool, optional): 是否缓存归一化邻接矩阵
    """

    def __init__(self, K, alpha=0.1, cached=True):
        super(GPRProp, self).__init__()
        self.K = K
        self.alpha = alpha
        self.cached = cached

        # 创建可学习的传播系数
        # GPR-GNN的核心在于这组可学习的传播系数
        self.temp = nn.Parameter(torch.Tensor(K + 1))
        self.reset_parameters()

        # 用于缓存归一化邻接矩阵
        self.cached_norm_edge_weight = None
        self.cached_edge_index = None

    def reset_parameters(self):
        # 使用预设的alpha参数来初始化传播系数
        # 初始化为(1-alpha) * alpha^k的形式，类似PageRank的衰减系数
        for k in range(self.K + 1):
            self.temp.data[k] = self.alpha * (1 - self.alpha) ** k

        # 使系数归一化为总和为1
        self.temp.data = self.temp.data / self.temp.data.sum()

    def forward(self, x, edge_index, edge_weight=None):
        """
        通过广义PageRank传播算法传播特征

        Args:
            x (Tensor): 节点特征矩阵
            edge_index (LongTensor): 边索引
            edge_weight (Tensor, optional): 边权重

        Returns:
            Tensor: 传播后的特征
        """
        # 如果没有提供边权重，使用全1边权重
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)

        # 准备归一化邻接矩阵
        row, col = edge_index
        N = x.size(0)

        # 检查是否需要重新计算归一化邻接矩阵
        if self.cached and self.cached_norm_edge_weight is not None and torch.equal(edge_index, self.cached_edge_index):
            edge_weight_norm = self.cached_norm_edge_weight
        else:
            # 计算度矩阵
            deg = torch_scatter.scatter_add(edge_weight, row, dim=0, dim_size=N)

            # 计算D^(-1/2)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

            # 对称归一化的边权重 D^(-1/2) A D^(-1/2)
            edge_weight_norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

            # 如果启用缓存，则保存计算结果
            if self.cached:
                self.cached_norm_edge_weight = edge_weight_norm
                self.cached_edge_index = edge_index

        # 高效的消息传递实现
        def message_passing(x_k):
            return torch_scatter.scatter_add(
                edge_weight_norm.view(-1, 1) * x_k[row], col, dim=0, dim_size=N
            )

        # 初始化输出为加权的自环 (H_0)
        out = self.temp[0] * x

        # 进行K次传播
        x_k = x
        for k in range(1, self.K + 1):
            # 通过消息传递得到下一步的特征
            x_k = message_passing(x_k)
            # 将特征加权后累加到输出
            out = out + self.temp[k] * x_k

        return out


class GPRGNN(nn.Module):
    def __init__(self, args, input_dim, output_dim, hid_dim):
        super(GPRGNN, self).__init__()
        self.dropout = args.dropout
        self.K = args.K if hasattr(args, 'K') else 10
        self.threshold = args.threshold
        # 设置激活函数
        self.activation = args.activation_fn
        if self.activation == 'relu':
            self.activation_fn = F.relu
        elif self.activation == 'leaky_relu':
            self.activation_fn = F.leaky_relu
        elif self.activation == 'tanh':
            self.activation_fn = F.tanh
        elif self.activation == 'sigmoid':
            self.activation_fn = F.sigmoid
        else:
            raise ValueError(f"Unsupported activation function: {self.activation}")

        # 初始化MLPs (两层神经网络，作为特征转换)
        # GPRGNN使用MLP先变换特征，再用GPR传播
        self.lin1 = nn.Linear(input_dim, hid_dim)
        self.lin2 = nn.Linear(hid_dim, output_dim)

        # GPR传播层
        self.prop = GPRProp(
            K=self.K,
            alpha=args.alpha if hasattr(args, 'alpha') else 0.1,
            cached=True
        )

        if hasattr(args, 'reset_param') and args.reset_param:
            self.reset_parameter()
        elif hasattr(args, 'rest_param') and args.rest_param:
            self.reset_parameter()

    def reset_parameter(self):
        nn.init.xavier_uniform_(self.lin1.weight.data)
        if self.lin1.bias is not None:
            self.lin1.bias.data.zero_()

        nn.init.xavier_uniform_(self.lin2.weight.data)
        if self.lin2.bias is not None:
            self.lin2.bias.data.zero_()

        self.prop.reset_parameters()

    def forward(self, data):

        data = self.filter_edges_by_threshold(data, self.threshold)

        x, edge_index = data.x, data.edge_index

        if hasattr(data, 'edge_weight') and data.edge_weight is not None:
            edge_weight = data.edge_weight
        else:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)

        # 特征转换
        x = self.lin1(x)
        x = self.activation_fn(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # 第二层变换
        x = self.lin2(x)

        # GPR传播
        h = self.prop(x, edge_index, edge_weight)

        # 返回logits和中间特征
        return h, x
    def filter_edges_by_threshold(self, data, threshold):
        """
        Filter edges with weights below the given threshold.

        Args:
            data: The graph data object containing edge_index and edge_weight
            threshold: The weight threshold below which edges will be removed

        Returns:
            Updated data object with filtered edges
        """
        # Check if edge weights exist
        if not hasattr(data, 'edge_weight') or data.edge_weight is None:
            # If no edge weights, return the original data
            return data

        # Create mask for edges to keep (where weight >= threshold)
        mask = data.edge_weight >= threshold

        # Apply mask to filter both edge_index and edge_weight
        data.edge_index = data.edge_index[:, mask]
        data.edge_weight = data.edge_weight[mask]

        return data