import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_scatter
from scipy.special import comb


class BernProp(nn.Module):
    """
    伯恩斯坦多项式传播层 (BernNet).

    使用伯恩斯坦多项式基函数，学习最优的传播系数，实现可控的消息传递。

    Args:
        K (int): 多项式的阶数，决定了BernNet的感受野大小
        cached (bool, optional): 是否缓存归一化邻接矩阵
    """

    def __init__(self, K, cached=True):
        super(BernProp, self).__init__()
        self.K = K
        self.cached = cached

        # 创建可学习的伯恩斯坦多项式系数
        # BernNet的核心在于这组可学习的传播系数
        self.coeffs = nn.Parameter(torch.Tensor(K + 1))
        self.reset_parameters()

        # 预计算伯恩斯坦多项式基函数的系数
        self.bern_coeff = torch.Tensor([(comb(K, i)) for i in range(K + 1)])

        # 用于缓存归一化邻接矩阵
        self.cached_norm_edge_weight = None
        self.cached_edge_index = None

    def reset_parameters(self):
        # 初始化为均匀分布
        nn.init.uniform_(self.coeffs, 0.0, 1.0)

    def forward(self, x, edge_index, edge_weight=None):
        """
        通过伯恩斯坦多项式传播算法传播特征

        Args:
            x (Tensor): 节点特征矩阵
            edge_index (LongTensor): 边索引
            edge_weight (Tensor, optional): 边权重

        Returns:
            Tensor: 传播后的特征
        """
        # 如果没有提供边权重，使用全1边权重
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)

        # 准备归一化邻接矩阵
        row, col = edge_index
        N = x.size(0)

        # 检查是否需要重新计算归一化邻接矩阵
        if self.cached and self.cached_norm_edge_weight is not None and torch.equal(edge_index, self.cached_edge_index):
            edge_weight_norm = self.cached_norm_edge_weight
        else:
            # 计算度矩阵
            deg = torch_scatter.scatter_add(edge_weight, row, dim=0, dim_size=N)

            # 计算D^(-1/2)
            deg_inv_sqrt = deg.pow(-0.5)
            deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

            # 对称归一化的边权重 D^(-1/2) A D^(-1/2)
            edge_weight_norm = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

            # 如果启用缓存，则保存计算结果
            if self.cached:
                self.cached_norm_edge_weight = edge_weight_norm
                self.cached_edge_index = edge_index

        # 高效的消息传递实现
        def message_passing(x_k):
            return torch_scatter.scatter_add(
                edge_weight_norm.view(-1, 1) * x_k[row], col, dim=0, dim_size=N
            )

        # 使用伯恩斯坦多项式基函数计算输出
        device = x.device
        bern_coeff = self.bern_coeff.to(device)
        coeffs_softmax = F.softmax(self.coeffs, dim=0)

        # 初始化输出
        out = torch.zeros_like(x)

        # 计算 I - A_norm
        identity = torch.eye(N, device=device)

        # 计算所有T_k(x)表示
        Tx = [x]  # T_0(x) = x (自身特征)

        # T_1(x) = 对称归一化邻接矩阵传播一次
        x_prop = message_passing(x)
        Tx.append(x_prop)

        # 计算伯恩斯坦基函数 B_K^i(t) = C(K,i) * t^i * (1-t)^(K-i)
        # 其中t表示对称归一化邻接矩阵的固有值在[0,1]区间
        for k in range(self.K + 1):
            term = torch.zeros_like(x)
            for i in range(k + 1):
                # C(k,i) * T_i(x) * T_(k-i)(x)的组合
                if i <= len(Tx) - 1:
                    term = term + bern_coeff[i] * (Tx[i]) * (1 - i / self.K) ** (self.K - k)

            # 加权和
            out = out + coeffs_softmax[k] * term

            # 如果需要计算下一轮，继续传播
            if k < self.K and len(Tx) <= k + 1:
                x_prop = message_passing(x_prop)
                Tx.append(x_prop)

        return out


class BERNNET(nn.Module):
    def __init__(self, args, input_dim, output_dim, hid_dim):
        super(BERNNET, self).__init__()
        self.dropout = args.dropout
        self.K = args.K

        # 设置激活函数
        self.activation = args.activation_fn
        if self.activation == 'relu':
            self.activation_fn = F.relu
        elif self.activation == 'leaky_relu':
            self.activation_fn = F.leaky_relu
        elif self.activation == 'tanh':
            self.activation_fn = F.tanh
        elif self.activation == 'sigmoid':
            self.activation_fn = F.sigmoid
        else:
            raise ValueError(f"Unsupported activation function: {self.activation}")

        # 阈值过滤
        self.threshold = args.threshold if hasattr(args, 'threshold') else 0.0

        # 初始化MLPs (两层神经网络，作为特征转换)
        # BernNet使用MLP先变换特征，再用Bernstein传播
        self.lin1 = nn.Linear(input_dim, hid_dim)
        self.lin2 = nn.Linear(hid_dim, output_dim)

        # 伯恩斯坦传播层
        self.prop = BernProp(
            K=self.K,
            cached=True
        )

        if hasattr(args, 'reset_param') and args.reset_param:
            self.reset_parameter()
        elif hasattr(args, 'rest_param') and args.rest_param:
            self.reset_parameter()

    def reset_parameter(self):
        nn.init.xavier_uniform_(self.lin1.weight.data)
        if self.lin1.bias is not None:
            self.lin1.bias.data.zero_()

        nn.init.xavier_uniform_(self.lin2.weight.data)
        if self.lin2.bias is not None:
            self.lin2.bias.data.zero_()

        self.prop.reset_parameters()

    def forward(self, data):
        data = self.filter_edges_by_threshold(data, self.threshold)
        x, edge_index = data.x, data.edge_index

        if hasattr(data, 'edge_weight') and data.edge_weight is not None:
            edge_weight = data.edge_weight
        else:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)

        # 特征转换
        x = self.lin1(x)
        x = self.activation_fn(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        # 第二层变换
        x = self.lin2(x)

        # 伯恩斯坦传播
        h = self.prop(x, edge_index, edge_weight)

        # 返回logits和中间特征
        return h, x

    def filter_edges_by_threshold(self, data, threshold):
        """
        Filter edges with weights below the given threshold.

        Args:
            data: The graph data object containing edge_index and edge_weight
            threshold: The weight threshold below which edges will be removed

        Returns:
            Updated data object with filtered edges
        """
        # Check if edge weights exist
        if not hasattr(data, 'edge_weight') or data.edge_weight is None:
            # If no edge weights, return the original data
            return data

        # Create mask for edges to keep (where weight >= threshold)
        mask = data.edge_weight >= threshold

        # Apply mask to filter both edge_index and edge_weight
        data.edge_index = data.edge_index[:, mask]
        data.edge_weight = data.edge_weight[mask]

        return data