import numpy as np
import scipy.sparse as sp
from sklearn.cluster import KMeans
import torch
from .utils import pyg_to_scipy_adj

def _margin_scores(z, tau):
    """Return the per-node margin of embeddings ``z`` [N, D] at threshold ``tau``."""
    pos_mask = (z > tau).astype(np.float32)
    neg_mask = 1.0 - pos_mask
    pos_sum = (z * pos_mask).sum(axis=1)
    neg_sum = (z * neg_mask).sum(axis=1)
    # The two masks are complementary, so this equals D + 1e-6 for every node.
    counts = pos_mask.sum(axis=1) + neg_mask.sum(axis=1) + 1e-6
    return (pos_sum - neg_sum) / counts  # [N]


def _soft_assignments(z, centers, margin, temp):
    """Return row-stochastic soft assignments [N, K] from a margin-scaled softmax over
    negative squared distances to ``centers``."""
    d2 = ((z[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    # NOTE(review): a negative margin flips the sign of the logits, so such nodes
    # put more weight on *farther* centroids. Confirm this is intended.
    logits = (-d2 / temp) * margin[:, None]
    S = np.exp(logits - logits.max(axis=1, keepdims=True))  # max-shift for stability
    return S / S.sum(axis=1, keepdims=True)


def _coarsen_adjacency(A, labels, n_clusters):
    """Sum the non-zero entries of ``A`` into a [n_clusters, n_clusters] CSR matrix by node label."""
    coo = A.tocoo()
    keep = coo.data != 0  # same entries that ``A.nonzero()`` reports
    rows = labels[coo.row[keep]]
    cols = labels[coo.col[keep]]
    # Duplicate (row, col) pairs are summed by the constructor.
    return sp.csr_matrix((coo.data[keep], (rows, cols)), shape=(n_clusters, n_clusters))


def _adj_to_edge_index(A, device=None):
    """Return the stored (row, col) coordinates of sparse ``A`` as a [2, nnz] int64 tensor."""
    coo = A.tocoo()
    edge_index = torch.from_numpy(np.vstack((coo.row, coo.col)).astype(np.int64))
    return edge_index if device is None else edge_index.to(device)


def _clusters_by_label(labels, n_clusters):
    """Return a list whose k-th entry holds the (ascending) indices with label k.

    Labels that never occur yield empty arrays, so position k always means label k.
    """
    labels = np.asarray(labels)
    order = np.argsort(labels, kind="stable")
    cuts = np.searchsorted(labels[order], np.arange(1, n_clusters))
    return np.split(order, cuts)


@torch.no_grad()
def Make_tree_real2(X, edge_index, gnn_model, levels, ratio=0.2, temp=0.1, tau=0.5):
    """
    Hierarchical coarsening with margin-weighted soft assignments (NumPy for speed).

    Each of the ``levels - 1`` coarsening steps works as follows:
      * ``gnn_model`` embeds the current level's graph.
      * KMeans on the embeddings gives hard labels.
      * A margin-scaled softmax over centroid distances gives soft assignments ``S``.
      * The adjacency is coarsened by summing edge weights between hard-label groups.
      * Features are coarsened as ``S^T X`` (weighted sums, not averages).

    Args:
      X: node features [N, F], a torch tensor or NumPy array.
      edge_index: input connectivity, passed to ``pyg_to_scipy_adj`` and to
        ``gnn_model`` for the first level.
      gnn_model: callable ``(features_tensor, edge_index) -> tensor [N_l, D]``.
      levels: number of levels in the returned tree, including the input level.
      ratio: number of clusters per step is ``int(N_cur * ratio) + 1``, clamped to
        ``N_cur``. The last step, and any level with at most 2 nodes, uses 1 cluster.
      temp: softmax temperature applied to squared centroid distances.
      tau: threshold separating "positive" from "negative" embedding entries in the margin.

    Returns:
      treeG: list of dicts with keys IDX, clusters, adj (scipy csr), features (np).
        Level 0 has ``IDX = arange(N)`` and singleton clusters. For level l >= 1:
          * ``IDX`` is the hard label of every level-(l-1) node.
          * ``clusters[k]`` holds the level-(l-1) nodes with label k (possibly empty).
      S_assign_list: list of np arrays [N_l, K_l]
    """
    if isinstance(X, torch.Tensor):
        X = X.detach().cpu().numpy()
    N_start = X.shape[0]
    A = pyg_to_scipy_adj(edge_index, num_nodes=N_start).tocsr()
    device = edge_index.device if isinstance(edge_index, torch.Tensor) else None

    adj_list = [A]
    features_list = [X]
    parents = []
    S_assign_list = []
    level_edge_index = edge_index  # connectivity that matches the rows of the current X

    for level in range(levels - 1):
        # Node embeddings from the encoder, on the *current* level's graph.
        z = gnn_model(torch.tensor(X, dtype=torch.float32), level_edge_index).cpu().numpy()
        N_cur = z.shape[0]

        margin = _margin_scores(z, tau)

        if N_cur <= 2 or level == levels - 2:
            K = 1
        else:
            # Clamp so KMeans is never asked for more clusters than samples (ratio >= 1).
            K = min(int(N_cur * ratio) + 1, N_cur)

        kmeans = KMeans(n_clusters=K, random_state=42).fit(z)
        hard = kmeans.labels_
        S = _soft_assignments(z, kmeans.cluster_centers_, margin, temp)  # [N_cur, K]

        A = _coarsen_adjacency(A, hard, K)
        adj_list.append(A)
        # The next GNN call must see edges between the K coarse nodes, not the
        # original node ids. Edge weights stored in A are not forwarded.
        level_edge_index = _adj_to_edge_index(A, device)

        X = S.T @ X
        features_list.append(X)

        parents.append(hard)
        S_assign_list.append(S)

    treeG = [None] * levels
    for lvl in range(levels):
        if lvl == 0:
            IDX = np.arange(N_start)
            clusters = [[i] for i in IDX]
        else:
            IDX = parents[lvl - 1]
            # Index clusters by label so they line up with adj/features/S columns.
            clusters = _clusters_by_label(IDX, adj_list[lvl].shape[0])
        treeG[lvl] = dict(IDX=IDX, clusters=clusters, adj=adj_list[lvl], features=features_list[lvl])
    return treeG, S_assign_list

def _haar_difference_vector(n, members, k):
    """Return the k-th (1-based) Haar difference vector of length ``n`` supported on ``members``.

    Equal to sqrt((m-k)/(m-k+1)) * (e[members[k-1]] - sum_{t>=k} e[members[t]] / (m-k)),
    where m = len(members) and e[i] is the i-th unit vector.
    """
    m = len(members)
    vec = np.zeros(n)
    vec[members[k - 1]] = 1.0
    vec[members[k:]] = -(1 / (m - k))
    return np.sqrt((m - k) / (m - k + 1)) * vec


def HaarGOB_with_Sassign(treeG, S_assign_list):
    """
    Build a "global-then-local" Haar-style basis per level using cluster structure.

    Coarsest level (N0 nodes): the standard Haar basis, i.e. a constant (DC) vector
    followed by N0 - 1 hierarchical difference vectors.

    Each finer level (N1 nodes) gets, in order:
      * N0 vectors lifted from the coarser basis. Nodes in coarse cluster j take
        ``uc[l][j]``, multiplied by ``S_assign[node, l]`` when an assignment matrix
        exists for the level. The whole vector is then divided by
        ``max(sqrt(len(cluster l)), 1)``.
      * For every coarse cluster of size k > 1, k - 1 local difference vectors
        supported on that cluster.

    The vectors are stored in ``treeG[l]['u']`` as a list of 1-D NumPy arrays of
    length N_l. ``treeG`` is modified in place and also returned.

    NOTE(review): the result is orthonormal only when no soft weights are applied
    and all clusters at each level have the same size. Two reasons: the lift divides
    by the size of cluster ``l`` (the basis index) rather than cluster ``j``, and the
    S weights are not renormalized. Confirm this is intended.

    Args:
      treeG: list of level dicts. ``treeG[l+1]['clusters'][k]`` lists the level-l
        node indices belonging to coarse node k.
      S_assign_list: ``S_assign_list[l]`` is indexed [level-l node, level-(l+1) node].
        Levels past the end of the list use unit weights.

    Returns:
      treeG, with 'u' set on every level.

    Raises:
      ValueError: if the coarsest level has no nodes, or if some level's clusters do
        not split the finer level's nodes into non-empty groups (the basis would not
        contain exactly N_l vectors).
    """
    Ntr = len(treeG)
    N0 = len(treeG[Ntr - 1]['clusters'])
    if N0 == 0:
        raise ValueError("the coarsest level has no nodes")

    # Coarsest level: DC vector followed by hierarchical differences.
    uc = [1 / np.sqrt(N0) * np.ones(N0)]
    uc += [_haar_difference_vector(N0, np.arange(N0), l) for l in range(1, N0)]
    treeG[Ntr - 1]['u'] = uc

    # Propagate the basis down one level at a time.
    for j_tr in range(Ntr - 2, -1, -1):
        N1 = len(treeG[j_tr]['clusters'])
        coarse_clusters = treeG[j_tr + 1]['clusters']
        sizes = [len(c) for c in coarse_clusters]
        if min(sizes) == 0 or sum(sizes) != N1:
            raise ValueError(
                f"clusters of level {j_tr + 1} must split the {N1} nodes of level "
                f"{j_tr} into non-empty groups (sizes sum to {sum(sizes)}, "
                f"smallest is {min(sizes)})"
            )
        S_assign = S_assign_list[j_tr] if j_tr < len(S_assign_list) else None

        # Lift each coarse basis vector onto this level.
        u = []
        for l in range(N0):
            ucl = uc[l]
            lifted = np.zeros(N1)
            for j, idxj in enumerate(coarse_clusters):
                w = S_assign[idxj, l] if S_assign is not None else 1.0
                lifted[idxj] = ucl[j] * w
            u.append(lifted / max(np.sqrt(sizes[l]), 1.0))

        # Local Haar differences inside every coarse cluster with more than one member.
        for members in coarse_clusters:
            for k in range(1, len(members)):
                u.append(_haar_difference_vector(N1, members, k))

        treeG[j_tr]['u'] = u
        uc = u
        N0 = N1
    return treeG
