{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.sparse.linalg\n",
    "import scipy.spatial.distance\n",
    "from sklearn.cluster import SpectralClustering\n",
    "import scipy\n",
    "import scipy.sparse as sp\n",
    "import numpy as np\n",
    "import copy\n",
    "import torch\n",
    "from scipy.sparse import coo_matrix\n",
    "def adj2edge(adj):\n",
    "    \"\"\"\n",
    "    Compute edge_index and edge_weights for the adjacency matrix adj.\n",
    "\n",
    "    INPUTS:\n",
    "        adj          - adjacency matrix of a graph in a scipy sparse format\n",
    "                       (it is converted with .tocoo())\n",
    "    OUTPUTS:\n",
    "        edge_index   - int64 tensor of shape (2, nnz): row indices, then column indices\n",
    "        edge_weights - float32 tensor with the stored values of adj\n",
    "    \"\"\"\n",
    "    adj = adj.tocoo().astype(np.float64)\n",
    "    # Stack the index arrays in numpy first: building the tensor from nested\n",
    "    # Python lists converts every index one element at a time.\n",
    "    indices = np.vstack((adj.row, adj.col))\n",
    "    edge_index = torch.from_numpy(indices).long()\n",
    "    edge_weights = torch.as_tensor(adj.data, dtype=torch.float32)\n",
    "    return edge_index, edge_weights\n",
    "    \n",
    "def edge2adj(edge_index,edge_weight,num_nodes):\n",
    "    \"\"\"\n",
    "    Build a sparse adjacency tensor from edge_index and edge_weight.\n",
    "\n",
    "    INPUTS:\n",
    "        edge_index  - (2, nnz) integer tensor: row indices, then column indices\n",
    "        edge_weight - (nnz,) tensor with the value stored for each edge\n",
    "        num_nodes   - number of nodes; the result has shape (num_nodes, num_nodes)\n",
    "    OUTPUTS:\n",
    "        adj         - float32 sparse COO adjacency tensor\n",
    "    \"\"\"\n",
    "    # torch.sparse.FloatTensor is a deprecated legacy constructor;\n",
    "    # torch.sparse_coo_tensor is the supported way to build the same tensor.\n",
    "    adj = torch.sparse_coo_tensor(\n",
    "        edge_index, edge_weight, (num_nodes, num_nodes), dtype=torch.float32\n",
    "    )\n",
    "    return adj\n",
    "from sklearn.cluster import kmeans_plusplus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import scipy.sparse as sp\n",
    "from sklearn.cluster import KMeans\n",
    "from torch_geometric.utils import from_scipy_sparse_matrix\n",
    "\n",
    "def adj2edge(adj_sp):\n",
    "    \"\"\"\n",
    "    Turn a scipy sparse adjacency matrix into PyG edge tensors.\n",
    "\n",
    "    Returns the result of from_scipy_sparse_matrix(adj_sp) unchanged; callers\n",
    "    in this notebook unpack it as (edge_index, edge_weight).\n",
    "    \"\"\"\n",
    "    # NOTE(review): adj2edge is defined three times in this notebook; whichever\n",
    "    # definition was executed last is the one that is actually called.\n",
    "    converted = from_scipy_sparse_matrix(adj_sp)\n",
    "    return converted\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import scipy.sparse as sp\n",
    "from sklearn.cluster import KMeans\n",
    "from torch_geometric.utils import from_scipy_sparse_matrix\n",
    "\n",
    "def adj2edge(adj_sp):\n",
    "    \"\"\"Hand adj_sp to from_scipy_sparse_matrix and return its result as-is.\"\"\"\n",
    "    pyg_edges = from_scipy_sparse_matrix(adj_sp)\n",
    "    return pyg_edges\n",
    "\n",
    "def Make_tree_real1(\n",
    "    X, A, gnn_model,\n",
    "    levels, ratio=0.2,\n",
    "    temp=0.1, tau=0.5\n",
    "):\n",
    "    \"\"\"\n",
    "    Hierarchical graph coarsening with margin-score.\n",
    "\n",
    "    Each round embeds the current graph with gnn_model, clusters the\n",
    "    embeddings with KMeans and collapses the graph onto the clusters: the\n",
    "    hard labels merge the adjacency, a softmax assignment merges the features.\n",
    "\n",
    "    Args:\n",
    "        X: np.ndarray or torch.Tensor, shape (N, D)\n",
    "        A: scipy.sparse adjacency, shape (N, N)\n",
    "        gnn_model: callable mapping (x, edge_index) -> 2-D tensor of node scores\n",
    "        levels: int, total levels including original\n",
    "        ratio: float, cluster-count ratio per level\n",
    "        temp: float, softmax temperature\n",
    "        tau: float, threshold for margin-score masks\n",
    "\n",
    "    Returns:\n",
    "        treeG: list of `levels` dicts with 'IDX','clusters','adj','features'\n",
    "        S_assign_list: list of `levels - 1` soft-assignment matrices; entry l\n",
    "            has shape (nodes at level l, clusters at level l + 1)\n",
    "    \"\"\"\n",
    "    # Ensure features as numpy\n",
    "    if isinstance(X, torch.Tensor):\n",
    "        X = X.cpu().numpy()\n",
    "\n",
    "    adj_list = [A]\n",
    "    features_list = [X]\n",
    "    parents = []\n",
    "    S_assign_list = []\n",
    "    N_start = A.shape[0]\n",
    "\n",
    "    for level in range(levels - 1):\n",
    "        print(f\"--- Coarsening Level {level + 1} ---\")\n",
    "\n",
    "        # 1) edge_index for GNN, rebuilt from the current (coarsened) adjacency\n",
    "        edge_index, _ = adj2edge(sp.coo_matrix(A))\n",
    "\n",
    "        # 2) node_scores via GNN\n",
    "        with torch.no_grad():\n",
    "            z = gnn_model(torch.tensor(X, dtype=torch.float32), edge_index)\n",
    "        node_scores = z.cpu().numpy()        # shape (N_cur, D)\n",
    "        N_cur = node_scores.shape[0]\n",
    "\n",
    "        # 3) margin-score: (pos_sum - neg_sum)/(pos_count + neg_count)\n",
    "        #    pos_mask and neg_mask are complementary, so counts is the number\n",
    "        #    of score columns plus the 1e-6 guard.\n",
    "        pos_mask = (node_scores > tau).astype(np.float32)\n",
    "        neg_mask = 1.0 - pos_mask\n",
    "        pos_sum = (node_scores * pos_mask).sum(axis=1)\n",
    "        neg_sum = (node_scores * neg_mask).sum(axis=1)\n",
    "        counts = pos_mask.sum(axis=1) + neg_mask.sum(axis=1) + 1e-6\n",
    "        margin_score = (pos_sum - neg_sum) / counts    # shape (N_cur,)\n",
    "\n",
    "        # 4) decide cluster count; the last round always collapses to one node\n",
    "        if N_cur <= 2 or level == levels - 2:\n",
    "            K = 1\n",
    "        else:\n",
    "            # Clamp to N_cur: KMeans rejects n_clusters > n_samples, which\n",
    "            # int(N_cur * ratio) + 1 reaches when ratio is close to 1.\n",
    "            K = min(int(N_cur * ratio) + 1, N_cur)\n",
    "        print(f\"num_clusters = {K}, nodes = {N_cur}\")\n",
    "\n",
    "        # 5) fit real KMeans\n",
    "        kmeans = KMeans(n_clusters=K, random_state=42).fit(node_scores)\n",
    "        centers = kmeans.cluster_centers_  # (K, D)\n",
    "        hard_labels = kmeans.labels_       # (N_cur,)\n",
    "\n",
    "        # 6) squared distances (N_cur, K)\n",
    "        dists = np.sum(\n",
    "            (node_scores[:, None, :] - centers[None, :, :])**2,\n",
    "            axis=2\n",
    "        )\n",
    "\n",
    "        # 7) negative logits scaled by margin_score\n",
    "        #    logits[i,k] = -dists[i,k]/temp * margin_score[i]\n",
    "        logits = (-dists / temp) * margin_score[:, None]\n",
    "\n",
    "        # 8) softmax row-wise (row maximum subtracted for numerical stability)\n",
    "        m = logits.max(axis=1, keepdims=True)\n",
    "        exp_l = np.exp(logits - m)\n",
    "        S_assign = exp_l / exp_l.sum(axis=1, keepdims=True)\n",
    "\n",
    "        # 9) coarsen adjacency via labels; csr_matrix sums entries that land\n",
    "        #    on the same (cluster, cluster) pair\n",
    "        rr, cc, vv = sp.find(A)\n",
    "        order = np.argsort(rr)\n",
    "        rr, cc, vv = rr[order], cc[order], vv[order]\n",
    "        nrr, ncc = hard_labels[rr], hard_labels[cc]\n",
    "        A = sp.csr_matrix((vv, (nrr, ncc)), shape=(K, K))\n",
    "        adj_list.append(A)\n",
    "\n",
    "        # 10) coarsen features: S^T X\n",
    "        X = S_assign.T.dot(X)\n",
    "        features_list.append(X)\n",
    "\n",
    "        # 11) record parents & soft-assign\n",
    "        parents.append(hard_labels)\n",
    "        S_assign_list.append(S_assign)\n",
    "\n",
    "    # 12) build output treeG\n",
    "    treeG = [None] * levels\n",
    "    for lvl in range(levels):\n",
    "        if lvl == 0:\n",
    "            idxs = np.arange(N_start)\n",
    "            clusters = [[i] for i in idxs]\n",
    "        else:\n",
    "            # group the nodes of level lvl-1 by the cluster label they received\n",
    "            pid = parents[lvl - 1]\n",
    "            order = np.argsort(pid)\n",
    "            vals, idx0 = np.unique(pid[order], return_index=True)\n",
    "            clusters = np.split(order, idx0[1:])\n",
    "        treeG[lvl] = {\n",
    "            'IDX': (np.arange(N_start) if lvl == 0 else parents[lvl - 1]),\n",
    "            'clusters': clusters,\n",
    "            'adj': adj_list[lvl],\n",
    "            'features': features_list[lvl]\n",
    "        }\n",
    "\n",
    "    return treeG, S_assign_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "def HaarGOB_with_Sassign(treeG, S_assign_list):\n",
    "    \"\"\"\n",
    "    Attach a list of Haar-style basis vectors ('u') to every level of treeG.\n",
    "\n",
    "    The coarsest level gets the normalised constant vector plus difference\n",
    "    vectors. Walking down the tree, each coarse vector is lifted to the finer\n",
    "    level (scaled by a column of that level's soft assignment matrix) and each\n",
    "    cluster with more than one member contributes difference vectors supported\n",
    "    on its members.\n",
    "\n",
    "    INPUT:\n",
    "        treeG: list of per-level dicts; 'clusters' of level j+1 holds, for\n",
    "            each of its nodes, the indices of the member nodes at level j.\n",
    "        S_assign_list: soft assignment matrices; entry j is indexed as\n",
    "            [node at level j, column].\n",
    "\n",
    "    OUTPUT:\n",
    "        treeG: the same list, mutated in place so every level has a 'u' entry.\n",
    "    \"\"\"\n",
    "    depth = len(treeG)\n",
    "\n",
    "    # --- basis at the coarsest level ---\n",
    "    n_coarse = len(treeG[depth - 1]['clusters'])\n",
    "    eye = np.identity(n_coarse)\n",
    "    coarse_basis = [None] * n_coarse\n",
    "    coarse_basis[0] = 1 / np.sqrt(n_coarse) * np.ones(n_coarse)\n",
    "    for l in range(1, n_coarse):\n",
    "        rest = n_coarse - l\n",
    "        tail = np.sum(eye[l:, :], axis=0)\n",
    "        coarse_basis[l] = np.sqrt(rest / (rest + 1)) * (eye[l - 1, :] - 1 / rest * tail)\n",
    "    treeG[depth - 1]['u'] = coarse_basis\n",
    "\n",
    "    # --- walk down towards level 0 ---\n",
    "    for level in range(depth - 2, -1, -1):\n",
    "        n_fine = len(treeG[level]['clusters'])\n",
    "        parent_clusters = treeG[level + 1]['clusters']\n",
    "        soft = S_assign_list[level]\n",
    "        fine_basis = [None] * n_fine\n",
    "        # the first n_coarse slots hold lifted vectors; difference vectors follow\n",
    "        slot = n_coarse\n",
    "\n",
    "        for l in range(n_coarse):\n",
    "            members = parent_clusters[l]\n",
    "            size = len(members)\n",
    "            coarse_vec = coarse_basis[l]\n",
    "\n",
    "            # lift coarse vector l: every member of cluster j receives entry j,\n",
    "            # weighted by column l of the soft assignment\n",
    "            lifted = np.zeros(n_fine)\n",
    "            for j in range(n_coarse):\n",
    "                children = parent_clusters[j]\n",
    "                lifted[children] = coarse_vec[j] * soft[children, l]\n",
    "            fine_basis[l] = lifted / np.sqrt(size)\n",
    "\n",
    "            # difference vectors inside cluster l\n",
    "            if size > 1:\n",
    "                member_rows = np.zeros((size, n_fine))\n",
    "                member_rows[np.arange(size), members] = 1\n",
    "                for k in range(1, size):\n",
    "                    left = size - k\n",
    "                    tail = np.sum(member_rows[k:, :], axis=0)\n",
    "                    fine_basis[slot] = np.sqrt(left / (left + 1)) * (\n",
    "                        member_rows[k - 1, :] - 1 / left * tail\n",
    "                    )\n",
    "                    slot += 1\n",
    "\n",
    "        # the finer basis becomes the coarse one for the next round\n",
    "        treeG[level]['u'] = fine_basis\n",
    "        coarse_basis = fine_basis\n",
    "        n_coarse = n_fine\n",
    "\n",
    "    return treeG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "#vanilla GNN\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_haar_basis_and_graph_info(tree_real):\n",
    "    \"\"\"\n",
    "    Collect per-level Haar bases, edge indices, sizes and features from tree_real.\n",
    "\n",
    "    For level j the basis matrix keeps the first N1 vectors of\n",
    "    tree_real[j]['u'] as columns, where N1 is the number of basis vectors of\n",
    "    level j+1 (1 for the last level).\n",
    "\n",
    "    Returns:\n",
    "        U: list of (N, N1) float64 Haar basis matrices, one per level\n",
    "        num_nodes_tree: int array of basis sizes per level (last entry forced to 1)\n",
    "        num_edges_tree: int array of edge counts per level (last entry forced to 1)\n",
    "        edge_index_list: list of edge_index tensors per level\n",
    "        features_list: list of the 'features' entry of every level\n",
    "    \"\"\"\n",
    "    n_levels = len(tree_real)\n",
    "    num_nodes_tree = np.zeros(n_levels, dtype=int)\n",
    "    num_edges_tree = np.zeros(n_levels, dtype=int)\n",
    "    edge_index_list = [None] * n_levels\n",
    "    U = []\n",
    "    features_list = []\n",
    "\n",
    "    for j, level in enumerate(tree_real):\n",
    "        basis_vectors = level['u']\n",
    "        n_rows = len(basis_vectors)\n",
    "        # column count comes from the next (coarser) level; the last level keeps one\n",
    "        if j < n_levels - 1:\n",
    "            n_cols = len(tree_real[j + 1]['u'])\n",
    "        else:\n",
    "            n_cols = 1\n",
    "\n",
    "        basis = np.zeros((n_rows, n_cols), dtype=np.float64)\n",
    "        for col in range(n_cols):\n",
    "            basis[:, col] = basis_vectors[col]\n",
    "        U.append(basis)\n",
    "\n",
    "        num_nodes_tree[j] = n_rows\n",
    "        level_edges, _ = adj2edge(level['adj'])\n",
    "        edge_index_list[j] = level_edges\n",
    "        num_edges_tree[j] = len(level_edges[0])\n",
    "        features_list.append(level['features'])\n",
    "\n",
    "    # the last level is always reported as one node with one edge\n",
    "    num_nodes_tree[-1] = 1\n",
    "    num_edges_tree[-1] = 1\n",
    "    return U, num_nodes_tree, num_edges_tree, edge_index_list, features_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "from sklearn.cluster import KMeans\n",
    "from torch_geometric.datasets import TUDataset\n",
    "from torch_geometric.data import DataLoader, Data\n",
    "from torch_geometric.utils import to_scipy_sparse_matrix, subgraph, from_scipy_sparse_matrix\n",
    "from torch_geometric.nn import GCNConv, global_mean_pool"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Node classification dataset (full graph per batch) ---\n",
    "import os, torch\n",
    "from torch_geometric.datasets import WebKB   # or: from torch_geometric.datasets import Planetoid\n",
    "from torch_geometric.loader import DataLoader\n",
    "\n",
    "# Dataset selection: the WebKB graph named below is stored under ./data/<name>\n",
    "# relative to the current working directory.\n",
    "name = 'Texas'  # choices: 'Texas', 'Cornell', 'Wisconsin'  (or use Actor/Planetoid)\n",
    "root = os.path.join(os.path.abspath(''), 'data', name)\n",
    "dataset = WebKB(root, name)                 # for Planetoid: dataset = Planetoid(root, 'Cora')\n",
    "\n",
    "# Only the first graph of the dataset is used.\n",
    "data = dataset[0]\n",
    "# Fallback if dataset has no features: give every node a constant 1-dim feature.\n",
    "if data.x is None or data.num_features == 0:\n",
    "    data.x = torch.ones(data.num_nodes, 1)\n",
    "\n",
    "# Full-graph loaders (batch_size=1)\n",
    "# All three loaders wrap the same single graph object.\n",
    "# NOTE(review): presumably the train/val/test split is applied later through\n",
    "# node masks - confirm against the training cells.\n",
    "train_loader = DataLoader([data], batch_size=1, shuffle=False)\n",
    "val_loader   = DataLoader([data], batch_size=1, shuffle=False)\n",
    "test_loader  = DataLoader([data], batch_size=1, shuffle=False)\n",
    "\n",
    "# Init a GNN encoder for coarsening (input dim = dataset.num_features)\n",
    "# in_dim falls back to 1 to match the constant-feature fallback above.\n",
    "in_dim  = dataset.num_features if dataset.num_features > 0 else 1\n",
    "hid_dim = 16\n",
    "out_dim = 32\n",
    "#gnn_model = GNNModel(in_dim, hid_dim, out_dim).eval()   # use eval() during coarsen\n",
    "\n",
    "\n",
    "\n",
    "\n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_batch_to_graphs(batch):\n",
    "    \"\"\"\n",
    "    Split a batched graph object into per-graph features, edges and labels.\n",
    "\n",
    "    For every distinct id in batch.batch the nodes carrying that id are\n",
    "    selected, their induced edges are extracted with relabelled node indices,\n",
    "    and batch.y[id] is taken as that graph's label.\n",
    "\n",
    "    Returns:\n",
    "        X_list: list of feature tensors, one per graph\n",
    "        edge_index_list: list of relabelled edge_index tensors, one per graph\n",
    "        y_list: list of 0-dim label tensors, one per graph\n",
    "    \"\"\"\n",
    "    X_list = []\n",
    "    edge_index_list = []\n",
    "    y_list = []\n",
    "    for gid in batch.batch.unique(sorted=True):\n",
    "        in_graph = batch.batch == gid\n",
    "        node_idx = in_graph.nonzero(as_tuple=False).view(-1)\n",
    "        graph_edges, _ = subgraph(node_idx, batch.edge_index, relabel_nodes=True)\n",
    "        X_list.append(batch.x[node_idx])\n",
    "        edge_index_list.append(graph_edges)\n",
    "        # NOTE(review): y is indexed by graph id, which assumes one label per graph - confirm.\n",
    "        y_list.append(batch.y[gid].view(()))\n",
    "    return X_list, edge_index_list, y_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Make_tree_real2(\n",
    "    X, edge_index, gnn_model,\n",
    "    levels, ratio=0.2,\n",
    "    temp=0.1, tau=0.5\n",
    "):\n",
    "    \"\"\"\n",
    "    Hierarchical graph coarsening with margin-score, starting from an edge_index.\n",
    "\n",
    "    Each round embeds the current graph with gnn_model, standardizes and clips\n",
    "    the embeddings, clusters them with KMeans and collapses the graph onto the\n",
    "    clusters: the hard labels merge the adjacency, a softmax assignment merges\n",
    "    the features.\n",
    "\n",
    "    Args:\n",
    "        X: np.ndarray or torch.Tensor, shape (N, D)\n",
    "        edge_index: (2, E) edge index tensor of the level-0 graph\n",
    "        gnn_model: callable mapping (x, edge_index) -> 2-D tensor of node scores\n",
    "        levels: int, total levels including original\n",
    "        ratio: float, cluster-count ratio per level\n",
    "        temp: float, softmax temperature\n",
    "        tau: float, threshold for margin-score masks\n",
    "\n",
    "    Returns:\n",
    "        treeG: list of `levels` dicts with 'IDX','clusters','adj','features'\n",
    "        S_assign_list: list of `levels - 1` soft-assignment matrices; entry l\n",
    "            has shape (nodes at level l, clusters at level l + 1)\n",
    "    \"\"\"\n",
    "    # Ensure features as numpy\n",
    "    if isinstance(X, torch.Tensor):\n",
    "        X = X.cpu().numpy()\n",
    "    A = to_scipy_sparse_matrix(edge_index, num_nodes=X.shape[0])\n",
    "    adj_list = [A]\n",
    "    features_list = [X]\n",
    "    parents = []\n",
    "    S_assign_list = []\n",
    "    N_start = X.shape[0]\n",
    "\n",
    "    for level in range(levels - 1):\n",
    "        print(f\"--- Coarsening Level {level + 1} ---\")\n",
    "\n",
    "        # 1) edge_index for GNN: level 0 uses the caller's edges; every later\n",
    "        #    level must be rebuilt from the coarsened adjacency, otherwise the\n",
    "        #    GNN would get edges that point at nodes of the original graph.\n",
    "        if level > 0:\n",
    "            edge_index, _ = from_scipy_sparse_matrix(sp.coo_matrix(A))\n",
    "\n",
    "        # 2) node_scores via GNN\n",
    "        with torch.no_grad():\n",
    "            z = gnn_model(torch.tensor(X, dtype=torch.float32), edge_index)\n",
    "        # Move to numpy before the numpy-based clean-up below.\n",
    "        z = z.detach().cpu().numpy()\n",
    "        # sanitize encoder outputs\n",
    "        z = np.nan_to_num(z, nan=0.0, posinf=1e6, neginf=-1e6)\n",
    "        # standardize per feature, then clip\n",
    "        z = (z - z.mean(axis=0, keepdims=True)) / (z.std(axis=0, keepdims=True) + 1e-6)\n",
    "        node_scores = np.clip(z, -8.0, 8.0)  # shape (N_cur, D)\n",
    "        N_cur = node_scores.shape[0]\n",
    "\n",
    "        # 3) margin-score: (pos_sum - neg_sum)/(pos_count + neg_count)\n",
    "        #    pos_mask and neg_mask are complementary, so counts is the number\n",
    "        #    of score columns plus the 1e-6 guard.\n",
    "        pos_mask = (node_scores > tau).astype(np.float32)\n",
    "        neg_mask = 1.0 - pos_mask\n",
    "        pos_sum = (node_scores * pos_mask).sum(axis=1)\n",
    "        neg_sum = (node_scores * neg_mask).sum(axis=1)\n",
    "        counts = pos_mask.sum(axis=1) + neg_mask.sum(axis=1) + 1e-6\n",
    "        margin_score = (pos_sum - neg_sum) / counts    # shape (N_cur,)\n",
    "\n",
    "        # 4) decide cluster count; the last round always collapses to one node\n",
    "        if N_cur <= 2 or level == levels - 2:\n",
    "            K = 1\n",
    "        else:\n",
    "            # Clamp to N_cur: KMeans rejects n_clusters > n_samples, which\n",
    "            # int(N_cur * ratio) + 1 reaches when ratio is close to 1.\n",
    "            K = min(int(N_cur * ratio) + 1, N_cur)\n",
    "\n",
    "        # 5) fit real KMeans\n",
    "        kmeans = KMeans(n_clusters=K, random_state=42).fit(node_scores)\n",
    "        centers = kmeans.cluster_centers_  # (K, D)\n",
    "        hard_labels = kmeans.labels_       # (N_cur,)\n",
    "\n",
    "        # 6) squared distances (N_cur, K)\n",
    "        dists = np.sum(\n",
    "            (node_scores[:, None, :] - centers[None, :, :])**2,\n",
    "            axis=2\n",
    "        )\n",
    "\n",
    "        # 7) negative logits scaled by margin_score\n",
    "        #    logits[i,k] = -dists[i,k]/temp * margin_score[i]\n",
    "        logits = (-dists / temp) * margin_score[:, None]\n",
    "\n",
    "        # 8) softmax row-wise (row maximum subtracted for numerical stability)\n",
    "        m = logits.max(axis=1, keepdims=True)\n",
    "        exp_l = np.exp(logits - m)\n",
    "        S_assign = exp_l / exp_l.sum(axis=1, keepdims=True)\n",
    "\n",
    "        # 9) coarsen adjacency via hard_labels; csr_matrix sums entries that\n",
    "        #    land on the same (cluster, cluster) pair\n",
    "        rr, cc, vv = sp.find(A)\n",
    "        order = np.argsort(rr)\n",
    "        rr, cc, vv = rr[order], cc[order], vv[order]\n",
    "        nrr, ncc = hard_labels[rr], hard_labels[cc]\n",
    "        A = sp.csr_matrix((vv, (nrr, ncc)), shape=(K, K))\n",
    "        adj_list.append(A)\n",
    "\n",
    "        # 10) coarsen features: S^T X\n",
    "        X = S_assign.T.dot(X)\n",
    "        features_list.append(X)\n",
    "\n",
    "        # 11) record parents & soft-assign\n",
    "        parents.append(hard_labels)\n",
    "        S_assign_list.append(S_assign)\n",
    "\n",
    "    # 12) build output treeG\n",
    "    treeG = [None] * levels\n",
    "    for lvl in range(levels):\n",
    "        if lvl == 0:\n",
    "            idxs = np.arange(N_start)\n",
    "            clusters = [[i] for i in idxs]\n",
    "        else:\n",
    "            # group the nodes of level lvl-1 by the cluster label they received\n",
    "            pid = parents[lvl - 1]\n",
    "            order = np.argsort(pid)\n",
    "            vals, idx0 = np.unique(pid[order], return_index=True)\n",
    "            clusters = np.split(order, idx0[1:])\n",
    "        treeG[lvl] = {\n",
    "            'IDX': (np.arange(N_start) if lvl == 0 else parents[lvl - 1]),\n",
    "            'clusters': clusters,\n",
    "            'adj': adj_list[lvl],\n",
    "            'features': features_list[lvl]\n",
    "        }\n",
    "\n",
    "    return treeG, S_assign_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def Uext_batch_from_tree_lists(\n",
    "    X_list, edge_index_list, gnn_model,\n",
    "    levels=5, ratio=0.3, temp=0.1, tau=0.5\n",
    "):\n",
    "    \"\"\"\n",
    "    Build the coarsening tree and Haar bases for every graph in a batch.\n",
    "\n",
    "    Each (features, edge_index) pair is turned into a scipy adjacency, coarsened\n",
    "    with Make_tree_real1, given Haar bases by HaarGOB_with_Sassign and then\n",
    "    unpacked with extract_haar_basis_and_graph_info.\n",
    "\n",
    "    Returns seven lists with one entry per graph, in this order: Haar bases,\n",
    "    per-level edge indices, per-level node counts, per-level edge counts,\n",
    "    per-level features, trees, soft-assignment lists.\n",
    "    \"\"\"\n",
    "    bases_per_graph = []\n",
    "    edges_per_graph = []\n",
    "    node_counts_per_graph = []\n",
    "    edge_counts_per_graph = []\n",
    "    feats_per_graph = []\n",
    "    trees_per_graph = []\n",
    "    soft_per_graph = []\n",
    "\n",
    "    for feats, edges in zip(X_list, edge_index_list):\n",
    "        adjacency = to_scipy_sparse_matrix(edges, num_nodes=feats.shape[0])\n",
    "        tree, soft_assignments = Make_tree_real1(\n",
    "            feats, adjacency, gnn_model,\n",
    "            levels=levels, ratio=ratio, temp=temp, tau=tau\n",
    "        )\n",
    "        tree = HaarGOB_with_Sassign(tree, soft_assignments)\n",
    "        info = extract_haar_basis_and_graph_info(tree)\n",
    "        bases, node_counts, edge_counts, level_edges, level_feats = info\n",
    "\n",
    "        bases_per_graph.append(bases)\n",
    "        edges_per_graph.append(level_edges)\n",
    "        node_counts_per_graph.append(node_counts)\n",
    "        edge_counts_per_graph.append(edge_counts)\n",
    "        feats_per_graph.append(level_feats)\n",
    "        trees_per_graph.append(tree)\n",
    "        soft_per_graph.append(soft_assignments)\n",
    "\n",
    "    return (\n",
    "        bases_per_graph, edges_per_graph, node_counts_per_graph,\n",
    "        edge_counts_per_graph, feats_per_graph, trees_per_graph, soft_per_graph\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-declare the layer sizes; hid_dim and out_dim here replace the values\n",
    "# assigned in the dataset cell above (16 and 32).\n",
    "in_dim  = dataset.num_features if dataset.num_features > 0 else 1\n",
    "hid_dim = 32\n",
    "out_dim = 16\n",
    "#gnn_model = GNNModel(in_dim, hid_dim, out_dim).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SparseHaarGNN_old(nn.Module):\n",
    "    \"\"\"\n",
    "    Node classifier that filters every tree level in its Haar basis with one\n",
    "    shared learnable filter and folds the coarse results back onto level 0.\n",
    "    \"\"\"\n",
    "    def __init__(self, input_dim, hidden_dim, output_dim,max_K,levels=None):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            input_dim: feature width of every level's feature matrix\n",
    "            hidden_dim: width of the shared projection used for aggregation\n",
    "            output_dim: number of logits per level-0 node\n",
    "            max_K: length of the shared filter; no Haar basis passed to\n",
    "                forward may have more columns than this\n",
    "            levels: stored as self.levels; forward does not read it\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "        self.levels = levels\n",
    "\n",
    "        # Shared filter across levels (diagonal of the spectral multiplier)\n",
    "        self.Lambda = nn.Parameter(torch.randn(max_K))\n",
    "\n",
    "        # Shared projection to the hidden width, then the classifier\n",
    "        self.mlp = nn.Linear(input_dim, hidden_dim)\n",
    "        self.classifier = nn.Linear(hidden_dim, output_dim)\n",
    "\n",
    "    def forward(self, X_list,U_list,treeG):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            X_list: features at each level [X0, X1, ..., XL]; X_l is (N_l, input_dim)\n",
    "            U_list: Haar basis matrices; U_l is (N_l, K_l)\n",
    "            treeG: per-level dicts; treeG[0]['features'] holds the level-0\n",
    "                features and treeG[l]['clusters'][i] the level l-1 members of\n",
    "                node i of level l\n",
    "\n",
    "        Returns:\n",
    "            logits of shape (N_0, output_dim)\n",
    "        \"\"\"\n",
    "        device = self.Lambda.device\n",
    "        X_list = [torch.as_tensor(X, dtype=torch.float32, device=device) for X in X_list]\n",
    "        self.U_list = [torch.as_tensor(U, dtype=torch.float32, device=device) for U in U_list]\n",
    "\n",
    "        # Spectral filtering per level: H_l = relu(U_l diag(Lambda) U_l^T X_l)\n",
    "        H_list = []\n",
    "        for X, U in zip(X_list, self.U_list):\n",
    "            K_l = U.shape[1]\n",
    "            if K_l > self.Lambda.numel():\n",
    "                raise ValueError(\n",
    "                    f\"Haar basis has {K_l} columns but the shared filter has only {self.Lambda.numel()} entries; increase max_K\"\n",
    "                )\n",
    "            # U is dense, so a plain matmul is used (torch.sparse.mm is meant\n",
    "            # for sparse operands).\n",
    "            X_hat = U.t() @ X                                   # (K_l, input_dim)\n",
    "            # Row-wise scaling is the same as multiplying by diag(Lambda[:K_l]).\n",
    "            H = U @ (self.Lambda[:K_l].unsqueeze(1) * X_hat)    # (N_l, input_dim)\n",
    "            H_list.append(F.relu(H))\n",
    "\n",
    "        # Level-0 features are projected once, outside the level loop, so the\n",
    "        # contributions of all coarse levels accumulate instead of being reset.\n",
    "        finer_features = torch.as_tensor(\n",
    "            treeG[0]['features'], dtype=torch.float32, device=device\n",
    "        )\n",
    "        finer_features = self.mlp(finer_features)               # (N_0, hidden_dim)\n",
    "\n",
    "        # Push coarse features down one level at a time. 'carried' always\n",
    "        # lives on the nodes of the level that is processed next; clusters of\n",
    "        # level l index nodes of level l-1, never level 0 directly.\n",
    "        carried = None\n",
    "        top_level = min(len(treeG), len(H_list)) - 1\n",
    "        for level in range(top_level, 0, -1):\n",
    "            # Project to hidden_dim so the widths match the level-0 features.\n",
    "            current = self.mlp(H_list[level])\n",
    "            if carried is not None:\n",
    "                current = current + carried\n",
    "            n_finer = H_list[level - 1].shape[0]\n",
    "            expanded = current.new_zeros(n_finer, current.shape[1])\n",
    "            for i, cluster in enumerate(treeG[level]['clusters']):\n",
    "                idx = torch.as_tensor(cluster, dtype=torch.long, device=device)\n",
    "                expanded[idx] = current[i]\n",
    "            carried = expanded\n",
    "\n",
    "        if carried is not None:\n",
    "            finer_features = finer_features + carried\n",
    "\n",
    "        # Final classification\n",
    "        logits = self.classifier(finer_features)\n",
    "        return logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import scipy.sparse as sp\n",
    "\n",
    "# ---- utilities ----\n",
    "def _to_dense_torch(mat, device):\n",
    "    \"\"\"Return mat (numpy array, scipy sparse matrix or array-like) as a dense float32 tensor on device.\"\"\"\n",
    "    if sp.issparse(mat):\n",
    "        dense = mat.toarray()\n",
    "    elif isinstance(mat, np.ndarray):\n",
    "        dense = mat\n",
    "    else:\n",
    "        # anything else goes through numpy's generic conversion\n",
    "        dense = np.asarray(mat)\n",
    "    return torch.as_tensor(dense, dtype=torch.float32, device=device)\n",
    "\n",
    "def unpool_one_level(H_coarse, clusters, N_fine):\n",
    "    \"\"\"\n",
    "    Copy each coarse node's feature row onto its children at the finer level.\n",
    "\n",
    "    H_coarse is [N_coarse, D]; clusters[i] lists the finer-level indices that\n",
    "    belong to coarse node i. Returns H_fine [N_fine, D]; rows that no cluster\n",
    "    mentions stay zero, and a row mentioned more than once is summed.\n",
    "    \"\"\"\n",
    "    feat_dim = H_coarse.size(1)\n",
    "    target_device = H_coarse.device\n",
    "    H_fine = torch.zeros(N_fine, feat_dim, device=target_device)\n",
    "    for coarse_id, children in enumerate(clusters):\n",
    "        if len(children) > 0:\n",
    "            rows = torch.as_tensor(children, dtype=torch.long, device=target_device)\n",
    "            # broadcast the single coarse row to one copy per child\n",
    "            repeated = H_coarse[coarse_id].expand(rows.numel(), feat_dim)\n",
    "            H_fine.index_add_(0, rows, repeated)\n",
    "    return H_fine\n",
    "\n",
    "def unpool_to_level0(H_l, level_l, treeG):\n",
    "    \"\"\"\n",
    "    Bring features from level 'level_l' down to level 0, one level at a time.\n",
    "\n",
    "    At each step m the clusters of treeG[m] (child indices at level m-1) and\n",
    "    the adjacency size of treeG[m-1] drive one call to unpool_one_level.\n",
    "    With level_l == 0 the input is returned untouched.\n",
    "    \"\"\"\n",
    "    H = H_l\n",
    "    m = level_l\n",
    "    while m > 0:\n",
    "        H = unpool_one_level(H, treeG[m]['clusters'], treeG[m - 1]['adj'].shape[0])\n",
    "        m -= 1\n",
    "    return H\n",
    "\n",
    "# ---- spectral block: U @ (diag(lambda) @ (U^T X)) ----\n",
    "class HaarSpectralBlock(nn.Module):\n",
    "    \"\"\"Spectral filter relu(U @ (lambda * (U^T @ X))) with one learnable vector lambda.\"\"\"\n",
    "\n",
    "    def __init__(self, max_K: int):\n",
    "        super().__init__()\n",
    "        # one coefficient per basis column; only the first K_l are used per call\n",
    "        self.lambda_vec = nn.Parameter(torch.randn(max_K))\n",
    "\n",
    "    def forward(self, U: torch.Tensor, X: torch.Tensor):\n",
    "        # U: [N_l, K_l] (dense), X: [N_l, F] -> output: [N_l, F]\n",
    "        coeffs = U.t() @ X                                   # [K_l, F]\n",
    "        n_basis = coeffs.size(0)\n",
    "        scale = self.lambda_vec[:n_basis].unsqueeze(1)       # [K_l, 1]\n",
    "        filtered = U @ (coeffs * scale)\n",
    "        return F.relu(filtered)\n",
    "\n",
    "# ---- node classifier that aggregates all levels at level 0 ----\n",
    "class NodeHaarUnpoolClassifier(nn.Module):\n",
    "    \"\"\"\n",
    "    Node classifier for a single graph with a coarsening tree.\n",
    "\n",
    "    Every used level is projected to the hidden width, filtered by one shared\n",
    "    HaarSpectralBlock and unpooled to level 0 through treeG[level]['clusters'];\n",
    "    the per-level results are concatenated and fed to a linear classifier.\n",
    "    \"\"\"\n",
    "    def __init__(self, in_dim: int, hid_dim: int, num_classes: int, max_K: int, num_levels: int):\n",
    "        super().__init__()\n",
    "        # how many levels to use (typically L-1; skip last 1-node level)\n",
    "        self.num_levels = num_levels\n",
    "        self.pre = nn.Linear(in_dim, hid_dim)\n",
    "        self.block = HaarSpectralBlock(max_K=max_K)\n",
    "        # the classifier input is one hid_dim slice per level\n",
    "        self.classifier = nn.Linear(hid_dim * num_levels, num_classes)\n",
    "        self.dropout = nn.Dropout(p=0.3)\n",
    "\n",
    "    def forward(self, U_list, features_list, treeG):\n",
    "        \"\"\"\n",
    "        U_list:        per-level bases [N_l, K_l] (numpy/scipy or array-like)\n",
    "        features_list: per-level features [N_l, Fin] for the same levels\n",
    "        treeG:         list of dicts with 'clusters' and 'adj' per level\n",
    "        Returns: logits over nodes at level 0, shape [N0, num_classes]\n",
    "        \"\"\"\n",
    "        device = next(self.parameters()).device\n",
    "        n_used = min(self.num_levels, len(U_list))\n",
    "\n",
    "        per_level = []\n",
    "        for l in range(n_used):\n",
    "            hidden = self.pre(_to_dense_torch(features_list[l], device))   # [N_l, H]\n",
    "            hidden = self.dropout(F.relu(hidden))\n",
    "            basis = _to_dense_torch(U_list[l], device)                     # [N_l, K_l]\n",
    "            filtered = self.block(basis, hidden)                           # [N_l, H]\n",
    "            # bring this level's result down to the level-0 nodes\n",
    "            per_level.append(unpool_to_level0(filtered, level_l=l, treeG=treeG))\n",
    "\n",
    "        # one block of H columns per level, side by side\n",
    "        stacked = self.dropout(torch.cat(per_level, dim=1))               # [N0, H * n_used]\n",
    "        return self.classifier(stacked)                                    # [N0, C]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [],
   "source": [
    "levels=4\n",
    "in_dim  = dataset.num_features if dataset.num_features > 0 else 1\n",
    "hid_dim = 64\n",
    "out_dim = dataset.num_classes\n",
    "max_K   = 64                                # large enough upper bound on K_l\n",
    "net = SparseHaarGNN_old(in_dim, hid_dim, out_dim, max_K=max_K, levels=levels-1)  # skip last degenerate level\n",
    "net.train()\n",
    "\n",
    "opt = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=5e-4)\n",
    "\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "model2= NodeHaarUnpoolClassifier(\n",
    "    in_dim=in_dim, hid_dim=hid_dim, num_classes=out_dim, max_K=max_K, num_levels=4-1\n",
    ")\n",
    "model=SparseHaarGNN_old(in_dim, hid_dim, out_dim,max_K=max_K, levels=4-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "\n",
    "from torch_sparse import spspmm, coalesce\n",
    "from torch_sparse import SparseTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch_sparse import SparseTensor\n",
    "\n",
    "class Hetero_Graph_Attention_Layer(nn.Module):\n",
    "    def __init__(self, in_features, out_features, dropout=0.1, alpha=0.2, num_layers=1):\n",
    "        super(Hetero_Graph_Attention_Layer, self).__init__()\n",
    "        self.dropout = dropout\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.alpha = alpha\n",
    "        self.num_layers = num_layers\n",
    "\n",
    "        # Learnable weight matrix for node embeddings\n",
    "        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))\n",
    "        nn.init.xavier_uniform_(self.W.data, gain=1.414)\n",
    "\n",
    "        # Multi-layer perceptron for attention weights\n",
    "        self.attention_mlp = nn.Sequential(\n",
    "            nn.Linear(2 * out_features, 16),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(16, 1)\n",
    "        )\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        nn.init.xavier_uniform_(self.W.data, gain=1.414)\n",
    "        for layer in self.attention_mlp:\n",
    "            if isinstance(layer, nn.Linear):\n",
    "                nn.init.xavier_uniform_(layer.weight, gain=1.414)\n",
    "                if layer.bias is not None:\n",
    "                    nn.init.constant_(layer.bias, 0)\n",
    "\n",
    "    def forward(self, h, edge_index):\n",
    "        \"\"\"\n",
    "        h: [num_nodes, in_features] (NO batch)\n",
    "        edge_index: [2, E]\n",
    "        \"\"\"\n",
    "        # Project features\n",
    "        Wh = torch.matmul(h, self.W)  # [N, Fout]\n",
    "\n",
    "        # Build sparse adjacency from edges\n",
    "        num_nodes = Wh.size(0)\n",
    "        adj = self._edge_index_to_adj(edge_index, num_nodes)\n",
    "        print(\"adj\", adj)\n",
    "\n",
    "        # (i) Feature attention on edges -> sparse matrix\n",
    "        similarity_scores = self._compute_attention_scores(Wh, edge_index)  # sparse [N,N]\n",
    "        similarity_softmax = self._sparse_softmax(similarity_scores)\n",
    "\n",
    "        # (ii) Structure similarity (Jaccard) on edges -> sparse matrix\n",
    "        structure_similarity = self.compute_structure_similarity_scores(edge_index, num_nodes)  # sparse [N,N]\n",
    "        structure_similarity_softmax = self._sparse_softmax(structure_similarity)\n",
    "\n",
    "        # Combine\n",
    "        attention_scores = similarity_softmax + structure_similarity_softmax  # sparse [N,N]\n",
    "        print(\"similarity_softmax\", similarity_softmax)\n",
    "        print(\"structure_similarity_softmax\", structure_similarity_softmax)\n",
    "        print(\"attention_scores\", attention_scores)\n",
    "\n",
    "        # Adaptive adjacency (convert to dense for formula, then back to sparse for mm)\n",
    "        adaptive_adj = self._compute_adaptive_adj(adj, attention_scores.to_dense())  # sparse [N,N]\n",
    "\n",
    "        # Message passing (same loop shape as your code)\n",
    "        h_prime = Wh\n",
    "        for _ in range(self.num_layers):\n",
    "            h_prime = torch.sparse.mm(adaptive_adj, h_prime)  # [N, Fout]\n",
    "            h_prime = F.dropout(h_prime, p=self.dropout, training=self.training)\n",
    "\n",
    "        return h_prime  # [N, Fout]\n",
    "\n",
    "    def _edge_index_to_adj(self, edge_index, num_nodes):\n",
    "        values = torch.ones(edge_index.size(1), device=edge_index.device)\n",
    "        adj = torch.sparse_coo_tensor(edge_index, values, (num_nodes, num_nodes))\n",
    "        return adj.coalesce()  # sorts & sums duplicates\n",
    "\n",
    "    def _compute_adaptive_adj(self, adj, S_class):\n",
    "        \"\"\"\n",
    "        adj: sparse [N,N]; S_class: dense [N,N] in [0,1]\n",
    "        hetero_adj = S_class * A + (1 - S_class) * (I - A)\n",
    "        Return sparse for sparse mm.\n",
    "        \"\"\"\n",
    "        N = adj.size(0)\n",
    "        A = adj.to_dense()                                # [N,N]\n",
    "        I = torch.eye(N, device=A.device)                 # [N,N]\n",
    "        hetero = S_class * A + (1.0 - S_class) * (I - A)  # [N,N] dense\n",
    "        return hetero.to_sparse().coalesce()              # sparse [N,N]\n",
    "\n",
    "    def _compute_attention_scores(self, Wh, edge_index):\n",
    "        \"\"\"\n",
    "        Wh: [N, Fout]; edge_index: [2, E]\n",
    "        Return sparse scores with values only on given edges.\n",
    "        \"\"\"\n",
    "        src_idx, dst_idx = edge_index[0], edge_index[1]     # [E]\n",
    "        src = Wh[src_idx, :]                                 # [E, Fout]\n",
    "        dst = Wh[dst_idx, :]                                 # [E, Fout]\n",
    "        attention_input = torch.cat([src, dst], dim=-1)      # [E, 2*Fout]\n",
    "\n",
    "        attention_scores = self.attention_mlp(attention_input).squeeze(-1)  # [E]\n",
    "        print(\"attention_scores\", attention_scores.shape)\n",
    "\n",
    "        num_nodes = Wh.size(0)\n",
    "        attention_scores_sparse = torch.sparse_coo_tensor(\n",
    "            edge_index, attention_scores, (num_nodes, num_nodes)\n",
    "        ).coalesce()\n",
    "        return attention_scores_sparse\n",
    "\n",
    "    def _sparse_softmax(self, scores_sparse):\n",
    "        \"\"\"\n",
    "        Row-wise softmax over neighbors (convert to dense for simplicity).\n",
    "        \"\"\"\n",
    "        dense = scores_sparse.to_dense() if scores_sparse.is_sparse else scores_sparse\n",
    "        dense_softmax = torch.softmax(dense, dim=-1)\n",
    "        return dense_softmax.to_sparse().coalesce()\n",
    "\n",
    "    def compute_structure_similarity_scores(self, edge_index, num_nodes):\n",
    "        \"\"\"\n",
    "        Jaccard similarity per edge, returned as a sparse matrix on those edges.\n",
    "        \"\"\"\n",
    "        row, col = edge_index[0].long(), edge_index[1].long()\n",
    "\n",
    "        # Sparse adjacency (torch_sparse)\n",
    "        adj = SparseTensor(row=row, col=col, sparse_sizes=(num_nodes, num_nodes))\n",
    "\n",
    "        # Common neighbors and degree\n",
    "        # NOTE: convert to dense BEFORE advanced indexing to get 1-D picks.\n",
    "        common_dense = (adj @ adj.t()).to_dense()   # [N, N]\n",
    "        degree = adj.sum(dim=1).to_dense()          # [N]\n",
    "\n",
    "        # Element-wise picks for each edge (u=row[i], v=col[i]) -> [E]\n",
    "        cn_edge = common_dense[row, col]            # [E]\n",
    "        deg_row = degree[row]                       # [E]\n",
    "        deg_col = degree[col]                       # [E]\n",
    "\n",
    "        total = deg_row + deg_col - cn_edge\n",
    "        jacc = (cn_edge / total.clamp_min(1e-9)).contiguous().view(-1)  # ensure 1-D [E]\n",
    "\n",
    "        # Build sparse scores on the given edges\n",
    "        jacc_sparse = torch.sparse_coo_tensor(\n",
    "            torch.stack([row, col], dim=0), jacc, (num_nodes, num_nodes)\n",
    "        ).coalesce()\n",
    "        return jacc_sparse\n",
    "\n",
    "    # (kept for parity with your snippet)\n",
    "    def compute_adaptive_adj(adj, S_class):\n",
    "        identity = torch.eye(adj.size(0), device=adj.device)\n",
    "        sparse_I = torch.sparse_coo_tensor(torch.arange(adj.size(0), device=adj.device).repeat(2,1),\n",
    "                                           torch.ones(adj.size(0), device=adj.device),\n",
    "                                           adj.size())\n",
    "        hetero_adj = S_class * adj + (1 - S_class) * (sparse_I - adj)\n",
    "        return hetero_adj\n",
    "gnn_hetero = Hetero_Graph_Attention_Layer(\n",
    "    in_features=in_dim, out_features=hid_dim, dropout=0.1, alpha=0.2, num_layers=2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Data(x=[183, 1703], edge_index=[2, 325], y=[183], train_mask=[183, 10], val_mask=[183, 10], test_mask=[183, 10])"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "adj tensor(indices=tensor([[0, 1, 2],\n",
      "                       [1, 2, 0]]),\n",
      "       values=tensor([1., 1., 1.]),\n",
      "       size=(3, 3), nnz=3, layout=torch.sparse_coo)\n",
      "attention_scores torch.Size([3])\n",
      "similarity_softmax tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],\n",
      "                       [0, 1, 2, 0, 1, 2, 0, 1, 2]]),\n",
      "       values=tensor([7.6965e-02, 8.4607e-01, 7.6965e-02, 3.0477e-01,\n",
      "                      3.0477e-01, 3.9046e-01, 9.9862e-01, 6.8851e-04,\n",
      "                      6.8851e-04]),\n",
      "       size=(3, 3), nnz=9, layout=torch.sparse_coo, grad_fn=<ToSparseBackward1>)\n",
      "structure_similarity_softmax tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],\n",
      "                       [0, 1, 2, 0, 1, 2, 0, 1, 2]]),\n",
      "       values=tensor([0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333,\n",
      "                      0.3333, 0.3333]),\n",
      "       size=(3, 3), nnz=9, layout=torch.sparse_coo)\n",
      "attention_scores tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2],\n",
      "                       [0, 1, 2, 0, 1, 2, 0, 1, 2]]),\n",
      "       values=tensor([0.4103, 1.1794, 0.4103, 0.6381, 0.6381, 0.7238, 1.3320,\n",
      "                      0.3340, 0.3340]),\n",
      "       size=(3, 3), nnz=9, layout=torch.sparse_coo, grad_fn=<AddBackward0>)\n",
      "Output node embeddings:\n",
      " tensor([[ 0.0000,  3.2374,  0.2248,  ...,  1.3425,  1.8912, -2.4821],\n",
      "        [ 3.6114,  0.5081,  2.9477,  ..., -0.1079,  0.0000,  0.0167],\n",
      "        [ 8.5343,  5.9848,  6.9663,  ...,  4.9305,  7.2758, -5.9319]],\n",
      "       grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "hetereo= Hetero_Graph_Attention_Layer(in_features=1703, out_features=360, dropout=0.1, alpha=0.2, num_layers=2)\n",
    "# Example usage\n",
    "h = torch.randn( 3, 1703)  # Batch size of 1, 3 nodes, 16 features each\n",
    "edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # Example edge index\n",
    "h_prime = hetereo(h, edge_index)\n",
    "print(\"Output node embeddings:\\n\", h_prime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "#Rwritten encoder for stability\n",
    "class HeteroGraphAttentionLayer(nn.Module):\n",
    "    \"\"\"\n",
    "    Input:\n",
    "      h:          [N, Fin]\n",
    "      edge_index: [2, E] (src -> dst), zero-based node indices\n",
    "    Output:\n",
    "      out:        [N, Fout]\n",
    "    \"\"\"\n",
    "    def __init__(self, in_features, out_features, dropout=0.1, alpha=0.2):\n",
    "        super().__init__()\n",
    "        self.in_features = in_features\n",
    "        self.out_features = out_features\n",
    "        self.dropout = dropout\n",
    "        self.alpha = alpha  # (unused in this minimal example; could weight the two scores)\n",
    "\n",
    "        # Linear projection\n",
    "        self.W = nn.Parameter(torch.empty(in_features, out_features))\n",
    "        nn.init.xavier_uniform_(self.W, gain=1.414)\n",
    "\n",
    "        # MLP that scores (Wh_i || Wh_j)\n",
    "        self.attention_mlp = nn.Sequential(\n",
    "            nn.Linear(2 * out_features, 16),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(16, 1)\n",
    "        )\n",
    "        self.reset_parameters()\n",
    "\n",
    "    def reset_parameters(self):\n",
    "        nn.init.xavier_uniform_(self.W, gain=1.414)\n",
    "        for layer in self.attention_mlp:\n",
    "            if isinstance(layer, nn.Linear):\n",
    "                nn.init.xavier_uniform_(layer.weight, gain=1.414)\n",
    "                if layer.bias is not None:\n",
    "                    nn.init.zeros_(layer.bias)\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _edge_jaccard(self, edge_index, num_nodes):\n",
    "        \"\"\"\n",
    "        Dense, edge-wise Jaccard similarity on an unweighted graph:\n",
    "            J(u,v) = |N(u) ∩ N(v)| / |N(u) ∪ N(v)|\n",
    "        Returns: [E] in [0,1]\n",
    "        \"\"\"\n",
    "        row, col = edge_index  # [E], [E]\n",
    "        device = row.device\n",
    "        # Build dense adjacency (for clarity; replace for large graphs)\n",
    "        idx = torch.stack([row, col], dim=0)\n",
    "        vals = torch.ones(row.numel(), device=device)\n",
    "        A = torch.sparse_coo_tensor(idx, vals, (num_nodes, num_nodes)).to_dense()  # [N,N]\n",
    "        deg = A.sum(dim=1)  # [N]\n",
    "\n",
    "        common = (A @ A)[row, col]  # [E]: # common neighbors between row and col\n",
    "        union = deg[row] + deg[col] - common\n",
    "        jacc = common / (union.clamp_min(1e-9))\n",
    "        return jacc\n",
    "\n",
    "    def forward(self, h, edge_index):\n",
    "        \"\"\"\n",
    "        h: [N, Fin], edge_index: [2, E]\n",
    "        \"\"\"\n",
    "        assert h.dim() == 2, \"Expected h to be [num_nodes, in_features]\"\n",
    "        N = h.size(0)\n",
    "        src, dst = edge_index  # messages flow src -> dst\n",
    "\n",
    "        Wh = h @ self.W  # [N, Fout]\n",
    "\n",
    "        # --- (1) Feature-based attention score on edges -----------------------\n",
    "        e_input = torch.cat([Wh[dst], Wh[src]], dim=-1)  # [E, 2*Fout] (dst first is common in GAT)\n",
    "        e_sim = self.attention_mlp(e_input).squeeze(-1)  # [E]\n",
    "\n",
    "        # Softmax over incoming edges per destination node i (standard GAT)\n",
    "        # Compute softmax grouped by 'dst' without external libs:\n",
    "        #   alpha_ij = exp(e_ij - max_i) / sum_j exp(e_ij - max_i), for edges with same dst\n",
    "        max_per_dst = torch.full((N,), -float(\"inf\"), device=h.device)\n",
    "        max_per_dst.scatter_reduce_(0, dst, e_sim, reduce='amax', include_self=True)\n",
    "        e_centered = e_sim - max_per_dst[dst]\n",
    "        exp_e = torch.exp(e_centered)\n",
    "        sum_per_dst = torch.zeros(N, device=h.device).scatter_add_(0, dst, exp_e)\n",
    "        sim_alpha = exp_e / (sum_per_dst[dst].clamp_min(1e-9))  # [E]\n",
    "\n",
    "        # --- (2) Structure similarity (Jaccard) on edges, then normalize per dst\n",
    "        with torch.no_grad():\n",
    "            jacc = self._edge_jaccard(edge_index, N)  # [E] in [0,1]\n",
    "\n",
    "        # Normalize jaccard per dst with the same grouped softmax trick (optional choice)\n",
    "        max_j_per_dst = torch.full((N,), -float(\"inf\"), device=h.device)\n",
    "        max_j_per_dst.scatter_reduce_(0, dst, jacc, reduce='amax', include_self=True)\n",
    "        j_centered = jacc - max_j_per_dst[dst]\n",
    "        exp_j = torch.exp(j_centered)\n",
    "        sumj_per_dst = torch.zeros(N, device=h.device).scatter_add_(0, dst, exp_j)\n",
    "        struct_alpha = exp_j / (sumj_per_dst[dst].clamp_min(1e-9))  # [E]\n",
    "\n",
    "        # --- (3) Combine the two attentions and renormalize per dst -----------\n",
    "        # You can weight them, e.g., w in [0,1]. Here we average then re-softmax.\n",
    "        combined_logit = torch.log(sim_alpha.clamp_min(1e-9)) + torch.log(struct_alpha.clamp_min(1e-9))\n",
    "        # Re-softmax over dst:\n",
    "        max_c_per_dst = torch.full((N,), -float(\"inf\"), device=h.device)\n",
    "        max_c_per_dst.scatter_reduce_(0, dst, combined_logit, reduce='amax', include_self=True)\n",
    "        c_centered = combined_logit - max_c_per_dst[dst]\n",
    "        exp_c = torch.exp(c_centered)\n",
    "        sumc_per_dst = torch.zeros(N, device=h.device).scatter_add_(0, dst, exp_c)\n",
    "        alpha = exp_c / (sumc_per_dst[dst].clamp_min(1e-9))  # [E]\n",
    "\n",
    "        # --- (4) Message passing: out_i = sum_{j in N(i)} alpha_ij * Wh_j -----\n",
    "        out = torch.zeros(N, self.out_features, device=h.device, dtype=Wh.dtype)\n",
    "        out.index_add_(0, dst, Wh[src] * alpha.unsqueeze(-1))\n",
    "        out = F.dropout(out, p=self.dropout, training=self.training)\n",
    "        return out\n",
    "# ⛔️ Silence all warnings (NumPy, PyTorch, scikit-learn, …)\n",
    "import warnings, numpy as np, os\n",
    "warnings.filterwarnings(\"ignore\")                   # Python / sklearn / PyTorch\n",
    "np.seterr(all=\"ignore\")                             # NumPy runtime warnings\n",
    "os.environ[\"PYTHONWARNINGS\"] = \"ignore\"             # fallback for spawned threads\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1703"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "in_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output node embeddings:\n",
      " tensor([[ 2.0254, -1.8245,  0.0000,  ..., -2.9542, -1.0846,  0.0000],\n",
      "        [ 2.3997, -0.0706, -0.4407,  ...,  2.3256, -0.2691, -2.3474],\n",
      "        [-1.4309, -2.4621,  0.3186,  ..., -2.2328,  1.7604, -3.7888]],\n",
      "       grad_fn=<MulBackward0>)\n"
     ]
    }
   ],
   "source": [
    "#example usage\n",
    "hetereo= HeteroGraphAttentionLayer(\n",
    "    in_features=in_dim, out_features=364, dropout=0.1\n",
    ")\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "# Example usage\n",
    "h = torch.randn(3, in_dim)  # Batch size of 1, 3 nodes, 16 features each\n",
    "edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]])  # Example edge index\n",
    "h_prime = hetereo(h, edge_index)\n",
    "print(\"Output node embeddings:\\n\", h_prime)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "\n",
    "def loss_diversity_from_S(S_assign_list, device=None, eps=1e-9):\n",
    "    \"\"\"\n",
    "    L_div = sum_{ℓ} (1/|V^(ℓ)|) * sum_i H(row_i),\n",
    "    where H(p) = - sum_k p_k log p_k.\n",
    "    \"\"\"\n",
    "    L_div = 0.0\n",
    "    for S in S_assign_list:\n",
    "        # S may be np.ndarray; move to torch\n",
    "        if isinstance(S, np.ndarray):\n",
    "            S_t = torch.from_numpy(S)\n",
    "        else:\n",
    "            S_t = S\n",
    "        if device is not None:\n",
    "            S_t = S_t.to(device)\n",
    "        S_t = S_t.clamp_min(eps)\n",
    "        row_entropy = -(S_t * S_t.log()).sum(dim=1)  # [N_l]\n",
    "        L_div = L_div + row_entropy.mean()\n",
    "    return L_div\n",
    "\n",
    "def loss_reconstruction_from_treeG(treeG, device=None):\n",
    "    \"\"\"\n",
    "    L_rec = sum_{levels} || H^(ℓ) - U^(ℓ)^T (U^(ℓ) H^(ℓ)) ||_F^2,\n",
    "    where U^(ℓ) is built from treeG[ℓ]['u'] (list of N_l vectors of length N_l).\n",
    "    \"\"\"\n",
    "    L_rec = 0.0\n",
    "    for lvl in range(len(treeG)):\n",
    "        if 'u' not in treeG[lvl]:\n",
    "            continue\n",
    "        u_list = treeG[lvl]['u']\n",
    "        # Some levels may store None (skip safely)\n",
    "        if u_list is None or any(v is None for v in u_list):\n",
    "            continue\n",
    "\n",
    "        U_np = np.stack(u_list, axis=0)  # [N_l, N_l]\n",
    "        H_np = treeG[lvl]['features']    # [N_l, D]\n",
    "\n",
    "        U = torch.from_numpy(U_np.astype(np.float32))\n",
    "        H = torch.from_numpy(H_np.astype(np.float32))\n",
    "        if device is not None:\n",
    "            U = U.to(device)\n",
    "            H = H.to(device)\n",
    "\n",
    "        H_hat = U.t() @ (U @ H)          # U^T U H\n",
    "        # Frobenius norm squared\n",
    "        L_rec = L_rec + F.mse_loss(H_hat, H, reduction='mean')\n",
    "    L_rec = L_rec / len(treeG)\n",
    "    return L_rec\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6057  Div=0.9847, Rec=0.4372  Total=1.4268\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6125  Div=0.9457, Rec=0.5851  Total=1.4431\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6066  Div=1.0892, Rec=0.3494  Total=1.4292\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6063  Div=1.1325, Rec=0.6072  Total=1.4590\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6060  Div=0.9407, Rec=0.8619  Total=1.4650\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6057  Div=0.9453, Rec=0.5442  Total=1.4335\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6164  Div=1.1199, Rec=0.5977  Total=1.4649\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6071  Div=1.1130, Rec=0.3382  Total=1.4308\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.6057  Div=1.0993, Rec=1.1898  Total=1.5134\n",
      "--- Coarsening Level 1 ---\n",
      "num_clusters = 55, nodes = 183\n",
      "--- Coarsening Level 2 ---\n",
      "num_clusters = 17, nodes = 55\n",
      "--- Coarsening Level 3 ---\n",
      "num_clusters = 6, nodes = 17\n",
      "--- Coarsening Level 4 ---\n",
      "num_clusters = 1, nodes = 6\n",
      "Graph 0: CE=1.5471  Div=0.9510, Rec=1.6813  Total=1.5009\n"
     ]
    }
   ],
   "source": [
    "lambda_div = 0.1 # set your λ_div\n",
    "lambda_rec = 0.1 # set your λ_rec\n",
    "\n",
    "for epoch in range(10):\n",
    "    model.train()\n",
    "    total_loss = 0.0\n",
    "\n",
    "    for batch in train_loader:\n",
    "        if batch.x is None:\n",
    "            batch.x = torch.ones(batch.num_nodes, 1)\n",
    "\n",
    "        X_list, edge_index_list, y_list = split_batch_to_graphs(batch)\n",
    "\n",
    "        (U_batch, eidx_batch, n_nodes_batch, n_edges_batch,\n",
    "         feats_batch, tree_batch, S_batch) = Uext_batch_from_tree_lists(\n",
    "            X_list, edge_index_list, hetereo,\n",
    "            levels=5, ratio=0.3, temp=0.1, tau=0.5\n",
    "        )\n",
    "\n",
    "        for i in range(len(U_batch)):\n",
    "            logits = model2(U_batch[i], feats_batch[i], tree_batch[i])\n",
    "            logits = logits.mean(dim=0)\n",
    "\n",
    "            y_i = y_list[i]\n",
    "            device = logits.device\n",
    "\n",
    "            # Core CE loss\n",
    "            L_ce = F.cross_entropy(logits, y_i)\n",
    "\n",
    "            # NEW: auxiliary losses\n",
    "            L_div = loss_diversity_from_S(S_batch[i], device=device)\n",
    "            L_rec = loss_reconstruction_from_treeG(tree_batch[i], device=device)\n",
    "\n",
    "            L_total = 0.8*L_ce + lambda_div * L_div + lambda_rec * L_rec\n",
    "\n",
    "            opt.zero_grad()\n",
    "            L_total.backward()\n",
    "            opt.step()\n",
    "\n",
    "            print (f\"Graph {i}: CE={L_ce.item():.4f}  Div={L_div.item():.4f}, Rec={L_rec.item():.4f}  Total={L_total.item():.4f}\")\n",
    "            \n",
    "\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
