import torch
import torch.nn as nn
from torch import Tensor
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.utils import add_self_loops


class APPNPPropagation(MessagePassing):
    """
    Implementation of APPNP propagation layer from the paper:
    "Predict then Propagate: Graph Neural Networks meet Personalized PageRank"
    (Klicpera et al., ICLR 2019)

    Starting from ``h = x`` the layer applies, ``K`` times,

        h <- (1 - alpha) * (A_hat @ h) + alpha * x

    where ``A_hat`` is the sparse adjacency after the optional self-loop and
    normalization steps. The layer defines no learnable parameters; its only
    state is the optionally cached adjacency.
    """

    def __init__(self, K: int, alpha: float, add_self_loops: bool = True,
                 cached: bool = False, normalization='sym'):
        """
        Initializes a new APPNP propagation layer.

        Parameters
        ----------
        K: int
            Number of propagation steps
        alpha: float
            Teleport probability (alpha in the paper)
        add_self_loops: bool, default: True
            Whether to add self-loops to the graph
        cached: bool, default: False
            Whether to cache the normalized adjacency matrix. Once cached, the
            stored matrix is reused on every later call and the ``edge_index``
            / ``edge_weight`` arguments of ``forward`` are ignored until
            ``reset_parameters`` is called.
        normalization: str, default: 'sym'
            Type of normalization to use ('sym' for symmetric, 'rw' for random
            walk). Any other value leaves the adjacency un-normalized.
        """
        super().__init__(aggr='add')
        self.K = K
        self.alpha = alpha
        self.add_self_loops = add_self_loops
        self.cached = cached
        self.normalization = normalization
        self.cached_adj = None

    def reset_parameters(self):
        """Drop the cached adjacency so the next forward pass rebuilds it."""
        self.cached_adj = None

    def forward(self, x: Tensor, edge_index, edge_weight=None) -> Tensor:
        """
        Forward pass of the APPNP propagation.

        Parameters
        ----------
        x: Tensor
            Node features; ``x.size(0)`` is used as the number of nodes.
        edge_index: Tensor or SparseTensor
            Either an edge list whose row 0 holds source nodes and row 1 holds
            target nodes, or an already transposed sparse adjacency
            (row = target, col = source), which is used as given.
        edge_weight: Tensor, optional
            Edge weights. Only used when ``edge_index`` is an edge list; a
            ``SparseTensor`` input carries its own values.

        Returns
        -------
        Tensor
            The propagated node features. With ``K == 0`` this is ``x`` itself.
        """
        if self.cached and self.cached_adj is not None:
            adj_t = self.cached_adj
        else:
            adj_t = self._build_adj_t(x, edge_index, edge_weight)
            if self.cached:
                self.cached_adj = adj_t

        # `x` is never modified in place, so it can serve directly as the
        # teleport term; no defensive copy is required.
        h = x
        for _ in range(self.K):
            # Propagate features
            h = self.propagate(adj_t, x=h)
            # Add teleport term with original features
            h = (1 - self.alpha) * h + self.alpha * x

        return h

    def _build_adj_t(self, x, edge_index, edge_weight):
        """Return the self-looped, normalized transposed adjacency for ``forward``."""
        if isinstance(edge_index, SparseTensor):
            adj_t = edge_index
        else:
            num_nodes = x.size(0)
            # Rows index target nodes and columns index source nodes, so that
            # `matmul(adj_t, x)` sums each node's incoming messages. This is
            # the same orientation expected from a SparseTensor input, keeping
            # both input formats consistent on directed graphs.
            adj_t = SparseTensor(
                row=edge_index[1], col=edge_index[0],
                value=edge_weight, sparse_sizes=(num_nodes, num_nodes)
            )

        if self.add_self_loops:
            adj_t = adj_t.set_diag()

        return self._normalize_adj(adj_t)

    def _normalize_adj(self, adj_t):
        """Scale ``adj_t`` according to ``self.normalization``."""
        # Implement GCN normalization manually since gcn_norm doesn't accept
        # a normalization param
        if self.normalization == 'sym':
            # Symmetric normalization: D^{-1/2} A D^{-1/2}
            deg_inv_sqrt = adj_t.sum(dim=1).pow(-0.5)
            # Zero-degree rows produce inf; zero them so they contribute nothing.
            deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
            adj_t = adj_t * deg_inv_sqrt.view(-1, 1)
            adj_t = adj_t * deg_inv_sqrt.view(1, -1)
        elif self.normalization == 'rw':
            # Random walk normalization: D^{-1} A
            deg_inv = adj_t.sum(dim=1).pow(-1)
            deg_inv.masked_fill_(deg_inv == float('inf'), 0)
            adj_t = adj_t * deg_inv.view(-1, 1)
        # Any other value: leave the adjacency un-normalized.
        return adj_t

    def message(self, x_j, norm=None, edge_weight=None):
        """
        Per-edge message: the neighbour feature, scaled by a weight if given.

        ``forward`` always hands ``propagate`` a SparseTensor, so aggregation
        normally goes through ``message_and_aggregate``. ``norm`` is kept for
        backward compatibility; ``edge_weight`` is accepted as well because
        the non-fused path is expected to supply the sparse values under that
        name (NOTE(review): confirm against the installed PyG version).
        """
        weight = norm if norm is not None else edge_weight
        if weight is None:
            return x_j
        return weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        """Fused message + sum aggregation as one sparse-dense matmul."""
        return matmul(adj_t, x)