import torch, numpy as np, scipy.sparse as sp
import torch.nn as nn
import torch.nn.functional as F

def _to_dense_torch(mat, device):
    """Convert a matrix-like object into a dense float32 tensor on ``device``.

    Args:
        mat: A ``torch.Tensor`` (strided or sparse layout), a NumPy array, a
            SciPy sparse matrix, or anything ``np.asarray`` accepts (for
            example nested lists).
        device: Target device, in any form ``torch`` accepts.

    Returns:
        A dense ``torch.float32`` tensor holding the values of ``mat``.
    """
    if isinstance(mat, torch.Tensor):
        # Stay inside torch: a NumPy round-trip raises for tensors that
        # require grad or that live on an accelerator.
        if mat.layout != torch.strided:
            mat = mat.to_dense()
        return mat.to(device=device, dtype=torch.float32)

    # np.asarray is a no-op for arrays that are already ndarrays.
    arr = mat.toarray() if sp.issparse(mat) else np.asarray(mat)
    return torch.as_tensor(arr, dtype=torch.float32, device=device)

def unpool_one_level(H_coarse, clusters, N_fine):
    """Copy each coarse node's feature row onto its children one level down.

    Args:
        H_coarse: 2-D tensor with one feature row per coarse node.
        clusters: Sequence whose ``i``-th entry lists the fine-level row
            indices that belong to coarse node ``i``. Empty entries are
            skipped.
        N_fine: Number of rows in the returned tensor.

    Returns:
        Tensor of shape ``(N_fine, D)`` with the dtype and device of
        ``H_coarse``. Rows that appear in no cluster stay zero; a row listed
        more than once receives the sum of the contributions.
    """
    device = H_coarse.device
    D = H_coarse.size(1)
    # index_add_ requires source and destination dtypes to match, so the
    # buffer must follow H_coarse rather than the float32 default.
    H_fine = torch.zeros(N_fine, D, dtype=H_coarse.dtype, device=device)
    for i, child_idx in enumerate(clusters):
        # len() rather than truthiness: child_idx may be a NumPy array, whose
        # truth value is ambiguous.
        if len(child_idx) == 0:
            continue
        idx = torch.as_tensor(child_idx, dtype=torch.long, device=device)
        H_fine.index_add_(0, idx, H_coarse[i].expand(idx.numel(), D))
    return H_fine

def unpool_to_level0(H_l, level_l, treeG):
    """Unpool features from level ``level_l`` down to level 0, one level at a time.

    For each level ``m`` from ``level_l`` down to 1, the clusters stored in
    ``treeG[m]['clusters']`` are used to spread the current features over the
    ``treeG[m-1]['adj'].shape[0]`` nodes of the next finer level.  When
    ``level_l`` is 0 the input is returned unchanged.
    """
    current = H_l
    for level in reversed(range(1, level_l + 1)):
        parent_clusters = treeG[level]['clusters']
        n_children = treeG[level - 1]['adj'].shape[0]
        current = unpool_one_level(current, parent_clusters, n_children)
    return current

class HaarSpectralBlock(nn.Module):
    """Spectral filter with one learnable scale per basis column.

    Holds a parameter vector ``lambda_vec`` of length ``max_K``, initialised
    from a standard normal distribution.
    """

    def __init__(self, max_K: int):
        super().__init__()
        self.lambda_vec = nn.Parameter(torch.randn(max_K))

    def forward(self, U: torch.Tensor, X: torch.Tensor):
        """Filter ``X`` in the basis given by the columns of ``U``.

        Args:
            U: Basis matrix of shape ``[N, K]``; only the first
                ``min(K, max_K)`` columns are used.
            X: Node features of shape ``[N, D]``.

        Returns:
            ``relu(Uc @ (lambda * (Uc^T @ X)))`` of shape ``[N, D]``, where
            ``Uc`` is the truncated basis and ``lambda`` the matching slice
            of ``lambda_vec``.
        """
        K_l   = U.size(1)
        # The basis may have more or fewer columns than there are learned scales.
        K_cap = min(K_l, self.lambda_vec.size(0))
        Uc    = U[:, :K_cap]
        X_hat = Uc.transpose(0, 1) @ X                  # [K_cap, D]
        lam   = self.lambda_vec[:K_cap].unsqueeze(1)    # [K_cap, 1], broadcasts over D
        X_hat = X_hat * lam
        H     = Uc @ X_hat                              # [N, D]
        return F.relu(H)

class FusionGate(nn.Module):
    """
    Scalar-gated blend of per-node raw features with Haar features.

    Output is (1 - g) * raw + g * proj(concat(raw, haar)), where
    g = sigmoid(alpha) is a single learnable scalar shared by all nodes.
    alpha is initialised to 0, so g starts at 0.5 (an equal blend).
    """
    def __init__(self, hid_dim):
        super().__init__()
        # One learnable logit controlling the mix for every node and channel.
        self.alpha = nn.Parameter(torch.zeros(1))
        # Maps the concatenated [raw, haar] pair back to hid_dim.
        self.proj  = nn.Linear(2 * hid_dim, hid_dim)

    def forward(self, raw, haar):
        mix = torch.sigmoid(self.alpha)
        candidate = self.proj(torch.cat([raw, haar], dim=-1))
        return (1 - mix) * raw + mix * candidate

class NodeHaarClassifierAggressive(nn.Module):
    """
    Node-classification–oriented:
      - retains raw level-0 features
      - adds multi-scale Haar features (levels 1..L-1), unpooled to level 0
      - fuses progressively via attention gates (no GCN)
      - LayerNorm + 2-layer MLP head

    Args:
        in_dim: Width of the input features at every level.
        hid_dim: Hidden width used by the shared encoder, the gates and the head.
        num_classes: Number of output logits per node.
        max_K: Number of learnable spectral scales in the Haar block.
        num_levels: Maximum number of hierarchy levels to use (level 0 included).
        p_drop: Dropout probability in the encoder and the head.
    """
    def __init__(self, in_dim: int, hid_dim: int, num_classes: int, max_K: int, num_levels: int,
                 p_drop=0.5):
        super().__init__()
        self.num_levels = num_levels

        # Preprocess any level's features into hid_dim (shared across levels)
        self.pre = nn.Sequential(
            nn.Linear(in_dim, hid_dim),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hid_dim, hid_dim)
        )

        # A single spectral block is reused for every level.
        self.block = HaarSpectralBlock(max_K=max_K)
        # gates for levels 1..num_levels-1
        self.gates = nn.ModuleList([FusionGate(hid_dim) for _ in range(max(0, num_levels-1))])

        self.norm = nn.LayerNorm(hid_dim)
        self.head = nn.Sequential(
            nn.Linear(hid_dim, hid_dim),
            nn.ReLU(),
            nn.Dropout(p_drop),
            nn.Linear(hid_dim, num_classes)
        )

    def forward(self, U_list, features_list, treeG):
        """Compute per-node class logits for the level-0 graph.

        Args:
            U_list: Per-level basis matrices; entry ``l`` is converted to a
                dense tensor and passed to the Haar block (entry 0 is unused).
            features_list: Per-level feature matrices with ``in_dim`` columns.
            treeG: Per-level hierarchy description handed to
                ``unpool_to_level0``.

        Returns:
            Logits of shape ``[N0, num_classes]``.
        """
        # Safety on levels: never index past the shortest input
        L_eff = min(self.num_levels, len(U_list), len(features_list), len(treeG))
        # Run on whatever device this module's weights live on instead of
        # relying on a module-level ``device`` variable.
        dev0  = next(self.parameters()).device

        # (A) raw level-0 path
        X0    = _to_dense_torch(features_list[0], dev0)  # [N0, Fin]
        fused = self.pre(X0)                             # [N0, H]

        # (B) fuse in Haar features from higher levels
        # we iterate levels 1..L_eff-1; gate index aligns with lvl-1
        for lvl in range(1, L_eff):
            X_l = _to_dense_torch(features_list[lvl], dev0)
            X_l = self.pre(X_l)                                      # [Nl, H]
            U_l = _to_dense_torch(U_list[lvl], dev0)                 # [Nl, Kl]
            H_l = self.block(U_l, X_l)                               # [Nl, H]
            H0_l = unpool_to_level0(H_l, level_l=lvl, treeG=treeG)   # [N0, H]
            fused = self.gates[lvl-1](fused, H0_l)                   # [N0, H]

        # (C) normalize + classify
        fused  = self.norm(fused)          # [N0, H]
        logits = self.head(fused)          # [N0, C]
        return logits

