import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor
from .utils import _to_dense_torch

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_sparse import SparseTensor
# Heterophily-aware attention encoder for a single graph given as one edge_index.
class Hetero_Graph_Attention_Layer(nn.Module):
    """Heterophily-aware attention layer for a single, unbatched graph.

    Each edge receives two scores:
      * a learned feature-attention score (an MLP over the concatenated
        projected endpoint features), and
      * the Jaccard similarity of the endpoints' out-neighbour sets.
    Each score is softmax-normalised over every node's neighbours, and the two
    are averaged into ``S``. Features projected by ``W`` are then propagated
    ``num_layers`` times with ``S * A + (1 - S) * (I - A)``.

    Note: several ``[N, N]`` matrices are materialised densely, so memory grows
    quadratically with the number of nodes.
    """

    def __init__(self, in_features, out_features, dropout=0.1, alpha=0.2, num_layers=1):
        super().__init__()
        self.dropout = dropout
        self.in_features = in_features
        self.out_features = out_features
        # Kept for API compatibility; not read by any method of this class.
        self.alpha = alpha
        self.num_layers = num_layers

        # Learnable projection of node features; initialised in reset_parameters().
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))

        # MLP scoring each edge from its concatenated endpoint embeddings.
        self.attention_mlp = nn.Sequential(
            nn.Linear(2 * out_features, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise W and the MLP weights (gain 1.414); zero the MLP biases."""
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        for layer in self.attention_mlp:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=1.414)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    def forward(self, h, edge_index):
        """Propagate node features through the adaptive adjacency.

        Args:
            h: [N, in_features] node features (no batch dimension).
            edge_index: [2, E] integer tensor of (source, target) node pairs.

        Returns:
            [N, out_features] propagated node features.
        """
        Wh = torch.matmul(h, self.W)  # [N, Fout]
        num_nodes = Wh.size(0)
        adj = self._edge_index_to_adj(edge_index, num_nodes)

        # (i) Learned feature attention on edges, normalised over neighbours.
        feature_attention = self._sparse_softmax(
            self._compute_attention_scores(Wh, edge_index)
        )
        # (ii) Structural (Jaccard) similarity on edges, normalised over neighbours.
        structure_attention = self._sparse_softmax(
            self.compute_structure_similarity_scores(edge_index, num_nodes)
        )

        # Average rather than sum: each term is row-stochastic, so the mean keeps
        # every entry in [0, 1], the range _compute_adaptive_adj's formula assumes.
        attention_scores = (feature_attention + structure_attention) * 0.5

        adaptive_adj = self._compute_adaptive_adj(adj, attention_scores.to_dense())

        h_prime = Wh
        for _ in range(self.num_layers):
            h_prime = torch.sparse.mm(adaptive_adj, h_prime)  # [N, Fout]
            h_prime = F.dropout(h_prime, p=self.dropout, training=self.training)
        return h_prime

    def _edge_index_to_adj(self, edge_index, num_nodes):
        """Build a coalesced sparse [N, N] adjacency with value 1 per edge.

        Duplicate edges are summed by ``coalesce``.
        """
        values = torch.ones(edge_index.size(1), device=edge_index.device)
        adj = torch.sparse_coo_tensor(edge_index, values, (num_nodes, num_nodes))
        return adj.coalesce()

    def _compute_adaptive_adj(self, adj, S_class):
        """Return ``S_class * A + (1 - S_class) * (I - A)`` as a sparse matrix.

        Args:
            adj: sparse [N, N] adjacency A.
            S_class: dense [N, N] edge weights; forward passes values in [0, 1].

        Returns:
            Coalesced sparse [N, N] matrix (exact zeros are dropped by ``to_sparse``).
        """
        N = adj.size(0)
        A = adj.to_dense()
        I = torch.eye(N, device=A.device)
        hetero = S_class * A + (1.0 - S_class) * (I - A)
        return hetero.to_sparse().coalesce()

    def _compute_attention_scores(self, Wh, edge_index):
        """Score each edge with the attention MLP.

        Args:
            Wh: [N, Fout] projected node features.
            edge_index: [2, E] (source, target) pairs.

        Returns:
            Coalesced sparse [N, N] tensor holding one score per given edge.
        """
        src_idx, dst_idx = edge_index[0], edge_index[1]
        attention_input = torch.cat([Wh[src_idx, :], Wh[dst_idx, :]], dim=-1)  # [E, 2*Fout]
        attention_scores = self.attention_mlp(attention_input).squeeze(-1)  # [E]

        num_nodes = Wh.size(0)
        return torch.sparse_coo_tensor(
            edge_index, attention_scores, (num_nodes, num_nodes)
        ).coalesce()

    def _sparse_softmax(self, scores_sparse):
        """Row-wise softmax over each node's stored (neighbour) entries.

        For a sparse input, only specified entries take part in the
        normalisation. Missing entries act as -inf (weight 0), so each non-empty
        row's neighbour weights sum to 1, and rows without entries stay empty.
        A dense input has no notion of missing entries and is softmaxed over
        every column, as before.
        """
        if not scores_sparse.is_sparse:
            return torch.softmax(scores_sparse, dim=-1).to_sparse().coalesce()
        return torch.sparse.softmax(scores_sparse.coalesce(), dim=1).coalesce()

    def compute_structure_similarity_scores(self, edge_index, num_nodes):
        """Jaccard similarity per edge, returned as a sparse [N, N] matrix.

        For edge (u, v): |N(u) ∩ N(v)| / (deg(u) + deg(v) - |N(u) ∩ N(v)|).
        Here N(.) is the out-neighbour set (a row of A) and the intersection
        size is read from A @ A^T.
        """
        row, col = edge_index[0].long(), edge_index[1].long()

        adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes))

        # Convert to dense BEFORE advanced indexing to get 1-D per-edge picks.
        common_dense = (adj @ adj.t()).to_dense()   # [N, N] common neighbours
        degree = adj.sum(dim=1).to_dense()          # [N] out-degree

        cn_edge = common_dense[row, col]            # [E]
        deg_row = degree[row]                       # [E]
        deg_col = degree[col]                       # [E]

        total = deg_row + deg_col - cn_edge
        jacc = (cn_edge / total.clamp_min(1e-9)).contiguous().view(-1)  # [E]

        return torch.sparse_coo_tensor(
            torch.stack([row, col], dim=0), jacc, (num_nodes, num_nodes)
        ).coalesce()

    @staticmethod
    def compute_adaptive_adj(adj, S_class):
        """Legacy sparse form of ``S_class * adj + (1 - S_class) * (I - adj)``.

        Not used by ``forward``. It previously lacked ``self``, so calling it
        through an instance failed; as a staticmethod it can be called either
        via the class or via an instance.
        """
        n = adj.size(0)
        sparse_I = torch.sparse_coo_tensor(
            torch.arange(n, device=adj.device).repeat(2, 1),
            torch.ones(n, device=adj.device),
            adj.size(),
        )
        return S_class * adj + (1 - S_class) * (sparse_I - adj)

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.sparse as sp

# ---- utilities ----
def _to_dense_torch(mat, device):
    """Convert ``mat`` to a dense ``torch.float32`` tensor on ``device``.

    Accepts a NumPy array, a SciPy sparse matrix, a torch tensor (any layout),
    or anything ``np.asarray`` understands (e.g. nested lists).
    """
    if isinstance(mat, torch.Tensor):
        # Stay inside torch: np.asarray fails on CUDA tensors and on tensors
        # that require grad, and would cut the autograd graph.
        if mat.layout != torch.strided:
            mat = mat.to_dense()
        return mat.to(device=device, dtype=torch.float32)
    arr = mat.toarray() if sp.issparse(mat) else np.asarray(mat)
    return torch.as_tensor(arr, dtype=torch.float32, device=device)

def unpool_one_level(H_coarse, clusters, N_fine):
    """
    Scatter coarse-level features to the next finer level.

    Args:
        H_coarse: [N_coarse, D] coarse node features.
        clusters: sequence whose i-th element holds the finer-level indices
            covered by coarse node i (list, array or tensor; may be empty).
        N_fine: number of nodes at the finer level.

    Returns:
        H_fine: [N_fine, D] tensor with the same dtype and device as H_coarse.
        H_fine[j] is the sum of H_coarse[i] over every cluster i that lists j,
        and 0 for nodes no cluster covers.
    """
    device = H_coarse.device
    # new_zeros matches H_coarse's dtype; index_add_ rejects mixed dtypes.
    H_fine = H_coarse.new_zeros(N_fine, H_coarse.size(1))
    child_chunks = [
        torch.as_tensor(c, dtype=torch.long, device=device).reshape(-1) for c in clusters
    ]
    if not child_chunks:
        return H_fine
    # One vectorised scatter: parent index repeated once per child it owns.
    sizes = torch.tensor([c.numel() for c in child_chunks], device=device)
    parents = torch.repeat_interleave(torch.arange(len(child_chunks), device=device), sizes)
    H_fine.index_add_(0, torch.cat(child_chunks), H_coarse[parents])
    return H_fine

def unpool_to_level0(H_l, level_l, treeG):
    """
    Carry features from level ``level_l`` down to level 0, one level per step.

    For each level m from ``level_l`` down to 1:
      * treeG[m]['clusters'] lists, for every node at level m, its child
        indices at level m-1;
      * treeG[m-1]['adj'].shape[0] gives the node count at level m-1.
    With ``level_l == 0`` the input is returned as-is.
    """
    features = H_l
    for level in reversed(range(1, level_l + 1)):
        child_clusters = treeG[level]['clusters']
        finer_size = treeG[level - 1]['adj'].shape[0]
        features = unpool_one_level(features, child_clusters, finer_size)
    return features

# ---- spectral block: U @ (diag(lambda) @ (U^T X)) ----
class HaarSpectralBlock(nn.Module):
    """Spectral filter ``ReLU(U @ diag(lambda[:K]) @ U^T @ X)`` with learnable lambda.

    ``lambda_vec`` holds ``max_K`` learnable coefficients. A basis with ``K_l``
    columns uses the first ``K_l`` of them, so ``K_l`` must not exceed ``max_K``.
    """

    def __init__(self, max_K: int):
        super().__init__()
        self.lambda_vec = nn.Parameter(torch.randn(max_K))

    def forward(self, U: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
        """Filter ``X`` in the basis ``U``.

        Args:
            U: [N_l, K_l] dense basis.
            X: [N_l, F] node features.

        Returns:
            [N_l, F] ReLU of the spectrally filtered features.

        Raises:
            ValueError: if ``K_l > max_K``. Otherwise the truncated lambda would
                fail to broadcast, or silently broadcast when ``max_K == 1``.
        """
        K_l = U.size(1)
        if K_l > self.lambda_vec.numel():
            raise ValueError(
                f"basis has {K_l} columns but only max_K={self.lambda_vec.numel()} "
                "spectral coefficients are available"
            )
        X_hat = U.transpose(0, 1) @ X                        # [K_l, F]
        X_hat = X_hat * self.lambda_vec[:K_l].unsqueeze(1)   # scale each frequency
        return F.relu(U @ X_hat)                             # [N_l, F]

# ---- node classifier that aggregates all levels at level 0 ----
class NodeHaarUnpoolClassifier(nn.Module):
    """
    Node classifier that combines spectral features from several Haar levels.

    For one graph:
      - Projects each level's features to ``hid_dim`` and applies a shared
        spectral block.
      - Unpools each level's output to level 0 using treeG[level]['clusters'].
      - Concatenates the per-level contributions at level 0 and classifies nodes.
    """
    def __init__(self, in_dim: int, hid_dim: int, num_classes: int, max_K: int, num_levels: int):
        super().__init__()
        # Number of levels consumed (typically L-1, skipping the final 1-node level).
        self.num_levels = num_levels
        self.pre = nn.Linear(in_dim, hid_dim)
        self.block = HaarSpectralBlock(max_K=max_K)
        # Input width is fixed, so forward must always supply num_levels levels.
        self.classifier = nn.Linear(hid_dim * num_levels, num_classes)
        self.dropout = nn.Dropout(p=0.3)

    def forward(self, U_list, features_list, treeG):
        """
        Args:
            U_list: list of [N_l, K_l] bases (numpy/scipy or torch), usually levels
                0..L-2; only the first ``num_levels`` entries are used.
            features_list: list of [N_l, Fin] features for the same levels.
            treeG: list of dicts with 'clusters' and 'adj' for levels 0..L-1.

        Returns:
            [N0, num_classes] class probabilities; each node's row sums to 1.
            NOTE(review): if callers train with nn.CrossEntropyLoss, that loss
            expects raw logits, not probabilities. TODO confirm the loss used.

        Raises:
            ValueError: if fewer than ``num_levels`` levels are supplied.
        """
        num_used = self.num_levels
        if len(U_list) < num_used or len(features_list) < num_used:
            raise ValueError(
                f"expected at least {num_used} levels, got {len(U_list)} bases "
                f"and {len(features_list)} feature matrices"
            )
        device = next(self.parameters()).device

        H_per_level = []
        for l in range(num_used):
            X_l = _to_dense_torch(features_list[l], device)   # [N_l, Fin]
            X_l = self.dropout(F.relu(self.pre(X_l)))         # [N_l, H]
            U_l = _to_dense_torch(U_list[l], device)          # [N_l, K_l]
            H_l = self.block(U_l, X_l)                        # [N_l, H]
            H_per_level.append(unpool_to_level0(H_l, level_l=l, treeG=treeG))  # [N0, H]

        H0_cat = self.dropout(torch.cat(H_per_level, dim=1))  # [N0, H * num_levels]
        logits = self.classifier(H0_cat)                      # [N0, C]
        # Normalise over classes (last dim), not over nodes.
        return F.softmax(logits, dim=-1)
