from typing import Optional

import torch
from torch import Tensor
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    SparseTensor,
)
from torch_geometric.utils import spmm
class Graph_Conv(MessagePassing):
    """Weight-free graph convolution: optional GCN normalisation + propagation.

    The layer owns no learnable parameters. ``forward`` optionally
    normalises the adjacency with ``gcn_norm`` and then aggregates
    neighbour features through ``MessagePassing.propagate``.

    Args:
        improved: Forwarded to ``gcn_norm`` (its ``improved`` option).
        cached: If ``True``, the normalised graph computed on the first
            ``forward`` call is stored and reused on every later call; the
            ``edge_index`` / ``edge_weight`` passed afterwards are then
            ignored until ``reset_parameters`` clears the cache.
        add_self_loops: Forwarded to ``gcn_norm``. ``None`` means "same as
            ``normalize``".
        normalize: Whether to run ``gcn_norm`` before propagating.
        **kwargs: Extra arguments for ``MessagePassing``; ``aggr`` defaults
            to ``'add'``.

    Raises:
        ValueError: If ``add_self_loops`` is requested while ``normalize``
            is ``False`` (self-loops are only inserted by the normalisation
            step).
    """

    # Cached (edge_index, edge_weight) pair for dense ``Tensor`` input.
    _cached_edge_index: Optional[OptPairTensor]
    # Cached normalised adjacency for ``SparseTensor`` input.
    _cached_adj_t: Optional[SparseTensor]

    def __init__(
        self,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: Optional[bool] = None,
        normalize: bool = True,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        if add_self_loops is None:
            add_self_loops = normalize
        if add_self_loops and not normalize:
            raise ValueError(f"'{self.__class__.__name__}' does not support "
                             f"adding self-loops to the graph when no "
                             f"on-the-fly normalization is applied")
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self._cached_edge_index = None
        self._cached_adj_t = None

    def reset_parameters(self):
        """Reset the base-class state and drop any cached normalised graph."""
        super().reset_parameters()
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj,
                edge_weight: OptTensor = None) -> Tensor:
        """Propagate ``x`` over the (optionally normalised) graph.

        Args:
            x: Node features; the node axis is ``self.node_dim``.
            edge_index: Graph connectivity as an edge-index ``Tensor`` or a
                ``SparseTensor`` adjacency.
            edge_weight: Optional per-edge weights.

        Returns:
            The aggregated node features produced by ``propagate``.
        """
        if self.normalize:
            num_nodes = x.size(self.node_dim)
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    # Forward the layer configuration; without these
                    # arguments `gcn_norm` silently falls back to its own
                    # defaults and ignores `improved` / `add_self_loops`.
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, num_nodes,
                        self.improved, self.add_self_loops, self.flow,
                        x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]
            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, num_nodes,
                        self.improved, self.add_self_loops, self.flow,
                        x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        return out

    def message(self, x_j: Tensor, edge_weight: OptTensor) -> Tensor:
        """Return the neighbour features, scaled per edge when weights exist."""
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        """Fused message + aggregation as a sparse-dense matrix product."""
        # `spmm` is imported from `torch_geometric.utils` at module level.
        return spmm(adj_t, x, reduce=self.aggr)


@torch.no_grad()
def prune_edge_index_by_attn(
    attn_bN: torch.Tensor,
    attn_Nb: torch.Tensor,
    edge_index: torch.Tensor,
    keep_ratio: float = 0.5,
    keep_self_loops: bool = True,
) -> torch.Tensor:
    """Keep only the highest-scoring fraction of edges.

    Each edge ``(row, col)`` is scored by the dot product of
    ``attn_Nb[row]`` and ``attn_bN[:, col]``. The top ``keep_ratio``
    fraction of edges (at least one) is retained; edges whose score ties
    with the cut-off are retained as well, so slightly more than the
    requested fraction may survive. The relative order of the surviving
    edges is unchanged.

    Args:
        attn_bN: 2-D scores indexed by node along dim 1. If dim 1 is smaller
            than the inferred node count while dim 0 is large enough, the
            tensor is transposed first.
        attn_Nb: 2-D scores indexed by node along dim 0, with the mirrored
            transpose fallback.
        edge_index: ``[2, E]`` integer tensor of (source, destination) pairs.
            The node count is inferred as the largest index plus one.
        keep_ratio: Fraction of candidate edges to keep. Values above 1 keep
            every edge; values that round down to zero still keep one edge.
        keep_self_loops: If ``True``, self-loops are always retained and are
            excluded from the ranking; otherwise they compete like any other
            edge.

    Returns:
        A new ``[2, E']`` tensor holding the retained columns of
        ``edge_index``.
    """
    # Nothing to rank: `max()` / `topk` would raise on an empty tensor.
    if edge_index.numel() == 0:
        return edge_index.clone()

    row, col = edge_index[0], edge_index[1]
    num_nodes = int(max(row.max(), col.max())) + 1

    # Accept either orientation of the score matrices.
    if attn_Nb.shape[0] < num_nodes and attn_Nb.shape[1] >= num_nodes:
        attn_Nb = attn_Nb.t()
    if attn_bN.shape[1] < num_nodes and attn_bN.shape[0] >= num_nodes:
        attn_bN = attn_bN.t()

    src_vec = attn_Nb[row]
    dst_vec = attn_bN[:, col].t()
    edge_score = (src_vec * dst_vec).sum(dim=-1)

    # Candidates are the edges that take part in the ranking; everything
    # else (self-loops when they are protected) is kept unconditionally.
    if keep_self_loops:
        candidates = row != col
    else:
        candidates = torch.ones_like(edge_score, dtype=torch.bool)
    mask_keep = ~candidates

    num_candidates = int(candidates.sum())
    if num_candidates > 0:
        # Clamp to [1, num_candidates] so keep_ratio > 1 cannot ask `topk`
        # for more elements than exist.
        k = min(max(1, int(num_candidates * keep_ratio)), num_candidates)
        threshold = torch.topk(edge_score[candidates], k).values[-1]
        mask_keep = mask_keep | (candidates & (edge_score >= threshold))

    return edge_index[:, mask_keep]
