from typing import List
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor

from greatx.nn.layers import Sequential, SoftMedianConv, activations
from greatx.utils import wrapper


def personalized_page_rank(adj_matrix, teleport_proba=0.15, neighbors=64):
    """Compute a dense, row-normalized personalized PageRank (PPR) matrix.

    Self-loops are added to the adjacency matrix, it is symmetrically
    normalized as ``D^{-1/2} (A + I) D^{-1/2}`` and then plugged into the
    closed form ``alpha * (I - (1 - alpha) * A_norm)^{-1}``. Optionally
    only the ``neighbors`` largest entries of every row are kept.

    Parameters
    ----------
    adj_matrix : torch.Tensor or SparseTensor
        square adjacency matrix; any input that is not a strided
        :class:`torch.Tensor` is densified through its ``to_dense()``
        method before anything else is read from it
    teleport_proba : float, optional
        teleport probability ``alpha`` of the PPR, by default 0.15
    neighbors : int or None, optional
        number of largest entries kept per row; with ``None`` or a value
        that is not smaller than the number of nodes the full matrix is
        kept, by default 64

    Returns
    -------
    torch.Tensor
        dense ``(N, N)`` matrix; the kept entries of every row are divided
        by their sum (clamped from below at ``1e-5``), all other entries
        are zero
    """
    # Densify before touching `.device`: torch_sparse's SparseTensor exposes
    # `device()` as a method, so reading it as an attribute would hand a
    # bound method to `torch.eye`.
    if (not isinstance(adj_matrix, torch.Tensor)
            or adj_matrix.layout != torch.strided):
        adj_matrix = adj_matrix.to_dense()

    device = adj_matrix.device
    N = adj_matrix.size(0)
    identity = torch.eye(N, device=device)

    # Add self-loops
    adj_matrix = adj_matrix + identity

    # Symmetric degree normalization (as in GCN); the clamp avoids a
    # division by zero for rows without any weight.
    deg_sqrt = adj_matrix.sum(dim=-1).clamp_min(1e-5).sqrt()
    adj_norm = adj_matrix / deg_sqrt.unsqueeze(1) / deg_sqrt.unsqueeze(0)

    # PPR computation: alpha * (I - (1-alpha) * A_norm)^(-1).
    # The 1e-6 jitter is applied to the diagonal only. Adding it as a bare
    # scalar would broadcast over every entry and leak (negative) PPR mass
    # between nodes that are not connected at all.
    system = (1 + 1e-6) * identity - (1 - teleport_proba) * adj_norm
    ppr_matrix = teleport_proba * torch.inverse(system)

    if neighbors is None or neighbors >= N:
        # Keep everything, just row-normalize.
        row_sum = ppr_matrix.sum(dim=-1, keepdim=True).clamp_min(1e-5)
        return ppr_matrix / row_sum

    # Top-k neighbors selection, renormalized over the kept entries.
    top_values, top_indices = ppr_matrix.topk(neighbors, dim=-1)
    top_values = top_values / top_values.sum(
        dim=-1, keepdim=True).clamp_min(1e-5)
    ppr_sparse = torch.zeros_like(ppr_matrix)
    ppr_sparse.scatter_(-1, top_indices, top_values)
    return ppr_sparse


class SoftMedianGCN(nn.Module):
    r"""Graph Convolution Network (GCN) built from
    soft median aggregation layers, as proposed in the
    `"Robustness of Graph Neural Networks
    at Scale" <https://arxiv.org/abs/2110.14038>`_ paper
    (NeurIPS'21)

    The model is a stack of :class:`SoftMedianConv` layers. Every hidden
    layer is followed by an optional :class:`BatchNorm1d`, its activation
    and a dropout layer; the last layer maps to :obj:`out_channels`.

    Parameters
    ----------
    in_channels : int,
        the input dimensions of model
    out_channels : int,
        the output dimensions of model
    hids : List[int], optional
        the number of hidden units for each hidden layer,
        by default [16]
    acts : List[str], optional
        the activation function for each hidden layer,
        by default ['relu']
    dropout : float, optional
        the dropout ratio of model, by default 0.5
    bias : bool, optional
        whether to use bias in the layers,
        by default True
    temperature : float, optional
        temperature parameter for softmax weighting in soft median
        aggregation, by default 1.0
    normalize : bool, optional
        whether to compute symmetric normalization
        coefficients on the fly, by default False
    row_normalize : bool, optional
        whether to perform row-normalization on the fly,
        by default False
    cached : bool, optional
        whether the layer will cache
        the computation of :math:`(\mathbf{\hat{D}}^{-1/2}
        \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2})`
        and sorted edges on first execution,
        and will use the cached version for further executions,
        by default True
    bn: bool, optional
        whether to use :class:`BatchNorm1d` after the convolution layer,
        by default False

    Examples
    --------
    >>> # SoftMedianGCN with one hidden layer
    >>> model = SoftMedianGCN(100, 10)

    >>> # SoftMedianGCN with two hidden layers
    >>> model = SoftMedianGCN(100, 10, hids=[32, 16], acts=['relu', 'elu'])

    >>> # SoftMedianGCN with two hidden layers, without first activation
    >>> model = SoftMedianGCN(100, 10, hids=[32, 16], acts=[None, 'relu'])

    >>> # SoftMedianGCN with deep architectures, each layer has elu activation
    >>> model = SoftMedianGCN(100, 10, hids=[16]*8, acts=['elu'])

    See also
    --------
    :class:`greatx.nn.layers.SoftMedianConv`

    """
    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: List[int] = [16], acts: List[str] = ['relu'],
                 dropout: float = 0.5, bias: bool = True,
                 temperature: float = 1.0, normalize: bool = False,
                 row_normalize: bool = False, cached: bool = True,
                 bn: bool = False):

        super().__init__()

        assert len(hids) == len(acts)

        # Options shared by every SoftMedianConv of the stack.
        conv_kwargs = dict(bias=bias, temperature=temperature,
                           normalize=normalize,
                           row_normalize=row_normalize, cached=cached)

        layers = []
        width = in_channels  # input size of the layer built next
        for hidden, act in zip(hids, acts):
            layers.append(SoftMedianConv(width, hidden, **conv_kwargs))
            if bn:
                layers.append(nn.BatchNorm1d(hidden))
            layers.append(activations.get(act))
            layers.append(nn.Dropout(dropout))
            width = hidden

        # Output layer: no normalization, activation or dropout afterwards.
        layers.append(SoftMedianConv(width, out_channels, **conv_kwargs))
        self.conv = Sequential(*layers)

    def reset_parameters(self):
        """Reset the parameters of all layers and drop cached results."""
        self.conv.reset_parameters()
        self.cache_clear()

    def cache_clear(self):
        """Reset ``_cached_edges`` on every layer that has this attribute
        and return the model itself."""
        for module in self.conv:
            if not hasattr(module, '_cached_edges'):
                continue
            module._cached_edges = None
        return self

    def forward(self, x, edge_index, edge_weight=None):
        """"""
        # Everything is delegated to the sequential stack of layers.
        out = self.conv(x, edge_index, edge_weight)
        return out


class SoftMedianGDC(nn.Module):
    r"""Graph Convolution Network with Soft Median aggregation and
    Personalized PageRank (GDC) preprocessing, following the GB implementation.

    On every forward pass the input graph is first turned into a dense
    adjacency matrix, diffused with :func:`personalized_page_rank` and
    converted back into an edge list, which is then fed to a stack of
    :class:`SoftMedianConv` layers.

    Parameters
    ----------
    in_channels : int
        the input dimensions of model
    out_channels : int
        the output dimensions of model
    hids : List[int], optional
        the number of hidden units for each hidden layer,
        by default [16]
    acts : List[str], optional
        the activation function for each hidden layer,
        by default ['relu']
    dropout : float, optional
        the dropout ratio of model, by default 0.5
    bias : bool, optional
        whether to use bias in the layers, by default True
    temperature : float, optional
        temperature parameter for softmax weighting, by default 1.0
    teleport_proba : float, optional
        teleport probability for PPR, by default 0.15
    neighbors : int, optional
        number of top neighbors to keep after PPR, by default 64
    cached : bool, optional
        whether to cache the PPR-processed adjacency matrix; once a result
        is cached it is reused for every later call until
        :meth:`cache_clear` is invoked, by default True
    bn : bool, optional
        whether to use BatchNorm1d, by default False
    """

    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: List[int] = [16], acts: List[str] = ['relu'],
                 dropout: float = 0.5, bias: bool = True,
                 temperature: float = 1.0, teleport_proba: float = 0.15,
                 neighbors: int = 64, cached: bool = True, bn: bool = False):

        super().__init__()

        self.teleport_proba = teleport_proba
        self.neighbors = neighbors
        self.cached = cached
        # Holds the (edge_index, edge_weight) pair of the PPR graph once
        # it has been computed with caching enabled.
        self._cached_ppr_adj = None

        # Build the SoftMedianGCN layers. Normalization, caching and
        # self-loops are switched off in the layers themselves because the
        # PPR preprocessing of this model takes care of them.
        conv = []
        assert len(hids) == len(acts)
        for hid, act in zip(hids, acts):
            conv.append(
                SoftMedianConv(in_channels, hid, bias=bias,
                               temperature=temperature, normalize=False,
                               row_normalize=False, cached=False,
                               add_self_loops=False))  # PPR already includes self-loops
            if bn:
                conv.append(nn.BatchNorm1d(hid))
            conv.append(activations.get(act))
            conv.append(nn.Dropout(dropout))
            in_channels = hid

        conv.append(
            SoftMedianConv(in_channels, out_channels, bias=bias,
                           temperature=temperature, normalize=False,
                           row_normalize=False, cached=False,
                           add_self_loops=False))  # PPR already includes self-loops
        self.conv = Sequential(*conv)

    def reset_parameters(self):
        """Reset the parameters of all layers and drop cached results."""
        self.conv.reset_parameters()
        self.cache_clear()

    def cache_clear(self):
        """Clear cached inputs or intermediate results."""
        self._cached_ppr_adj = None
        for conv in self.conv:
            if hasattr(conv, '_cached_edges'):
                conv._cached_edges = None
        return self

    def _preprocess_adj(self, edge_index, edge_weight, num_nodes):
        """Apply PPR preprocessing to adjacency matrix

        Parameters
        ----------
        edge_index : torch.Tensor
            edge list whose first row indexes the rows and whose second
            row indexes the columns of the dense adjacency matrix
        edge_weight : torch.Tensor or None
            one weight per edge; ``None`` means a weight of one per edge
        num_nodes : int
            number of rows/columns of the dense adjacency matrix

        Returns
        -------
        tuple
            ``(ppr_edge_index, ppr_edge_weight)``, the non-zero entries of
            the PPR matrix. With caching enabled and a populated cache the
            cached pair is returned and the arguments are ignored.
        """
        if self.cached and self._cached_ppr_adj is not None:
            return self._cached_ppr_adj

        # Convert edge_index to dense adjacency matrix
        if edge_weight is None:
            edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)

        adj_dense = torch.zeros(num_nodes, num_nodes, device=edge_index.device)
        # Indexed assignment with a tensor value requires source and
        # destination to share a dtype, so e.g. float64 or integer weights
        # are cast to the dtype of the dense buffer (no-op for float32).
        adj_dense[edge_index[0], edge_index[1]] = edge_weight.to(adj_dense.dtype)

        # Apply PPR preprocessing
        ppr_adj = personalized_page_rank(adj_dense, self.teleport_proba, self.neighbors)

        # Convert back to edge_index format (only non-zero entries survive)
        ppr_edge_index = ppr_adj.nonzero().t()
        ppr_edge_weight = ppr_adj[ppr_edge_index[0], ppr_edge_index[1]]

        result = (ppr_edge_index, ppr_edge_weight)
        if self.cached:
            self._cached_ppr_adj = result

        return result

    def forward(self, x, edge_index, edge_weight=None):
        """Forward pass with PPR preprocessing

        The number of nodes is taken from the first dimension of ``x``.
        """
        num_nodes = x.size(0)

        # Apply PPR preprocessing
        ppr_edge_index, ppr_edge_weight = self._preprocess_adj(edge_index, edge_weight, num_nodes)

        # Apply SoftMedian convolutions with PPR-processed adjacency
        return self.conv(x, ppr_edge_index, ppr_edge_weight)