# benchmarking_utils.py
"""
Utilities for node embedding benchmarking across datasets:
- dataset loaders (Cora, CiteSeer, PubMed, WikiCS, Amazon-Photo, ogbn-arxiv, ogbn-mag, ogbn-products)
- embedding generators: deepwalk, node2vec, vgae, dgi, fuse (unsupervised modularity maximisation),
  random, given (the dataset's own node features)
- classifiers: GCN, GAT, GraphSAGE (Spektral)
- training, evaluation, saving utilities
- run_benchmark() orchestrator
"""
import os
# Force TensorFlow to use CPU only — completely disable GPU visibility
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import time
import random
import numpy as np
import networkx as nx
from tqdm import tqdm
import scipy.sparse as sp
from scipy.sparse import lil_matrix, csr_matrix

# embedding libs
from node2vec import Node2Vec
import torch
from torch_geometric.nn import DeepGraphInfomax, VGAE as PyG_VGAE
from torch_geometric.nn import GCNConv as PyG_GCNConv
from torch_geometric.data import Data as PyGData

# Spektral + TensorFlow implement the GNN classifiers (GCN, GAT, GraphSAGE)
import tensorflow as tf
from spektral.layers import GCNConv, GATConv, GraphSageConv
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.decomposition import TruncatedSVD

# Torch-geo datasets + Spektral/similar dataset imports
from torch_geometric.datasets import Planetoid, WikiCS, Amazon
from spektral.datasets import Cora

import logging
logging.getLogger("gensim").setLevel(logging.ERROR)

# Suppress Python warnings
import warnings
warnings.filterwarnings('ignore')


# Second safety layer on top of CUDA_VISIBLE_DEVICES: hide any GPU from TensorFlow.
tf.config.set_visible_devices([], device_type='GPU')


# ----------------------------
# Configurable default
# ----------------------------
# Default embedding width (k) used by the embedding generators and the orchestrators.
DEFAULT_EMB_DIM = 150



def _edge_index_to_adjacency(edge_index, num_nodes):
    """
    Build a symmetric, binary float32 CSR adjacency matrix from a ``[2, E]`` edge list.

    Every edge (s, t) sets both A[s, t] and A[t, s] to 1. Repeated edges, and an
    edge listed in both directions, collapse to a single entry of value 1. This
    matches element-wise assignment into a LIL matrix, but it is vectorised
    instead of looping over all E edges in Python.
    """
    edge_index = np.asarray(edge_index)
    src = edge_index[0].astype(np.int64)
    dst = edge_index[1].astype(np.int64)
    rows = np.concatenate([src, dst])
    cols = np.concatenate([dst, src])
    values = np.ones(rows.shape[0], dtype=np.float32)
    a = sp.coo_matrix((values, (rows, cols)), shape=(num_nodes, num_nodes)).tocsr()
    a.sum_duplicates()
    # Duplicate (row, col) pairs were summed; reset every stored entry to 1 to keep A binary.
    a.data[:] = 1.0
    return a


def _ogb_load_context():
    """
    Return a context manager under which OGB datasets can be deserialised.

    Recent PyTorch releases make ``torch.load`` default to ``weights_only=True``.
    That setting rejects the PyG classes stored in OGB's processed files unless
    they are allow-listed, which ``torch.serialization.safe_globals`` does. If it
    (or the PyG classes) cannot be imported, a no-op context is returned instead.
    """
    import contextlib

    try:
        from torch.serialization import safe_globals
        from torch_geometric.data import Data
        from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
        from torch_geometric.data.storage import GlobalStorage
    except ImportError:
        return contextlib.nullcontext()
    return safe_globals([Data, DataEdgeAttr, DataTensorAttr, GlobalStorage])


def _load_ogb_graph(ogb_name, ogb_root):
    """Load (downloading if needed) the OGB node-property dataset ``ogb_name`` and return its graph."""
    from ogb.nodeproppred import PygNodePropPredDataset

    os.makedirs(ogb_root, exist_ok=True)
    with _ogb_load_context():
        dataset = PygNodePropPredDataset(name=ogb_name, root=ogb_root)
        return dataset[0]


def _unpack_pyg(d):
    """Return ``(features, labels, edge_index)`` of a PyG ``Data`` object as numpy arrays."""
    return d.x.numpy(), d.y.numpy(), d.edge_index.numpy()


def _pack_dataset(x, a, labels, edge_index):
    """
    Assemble the dictionary returned by :func:`load_dataset`.

    ``edge_index`` (shape ``[2, E]``) is only used for the PyG ``Data`` object.
    ``a`` is returned as-is and is also the source of the networkx graph.
    """
    labels = np.array(labels, dtype=int)
    num_classes = int(labels.max()) + 1
    y_onehot = np.eye(num_classes)[labels]
    pyg = PyGData(
        x=torch.tensor(x, dtype=torch.float),
        edge_index=torch.tensor(np.asarray(edge_index), dtype=torch.long),
        y=torch.tensor(labels, dtype=torch.long),
    )
    return {
        "x": np.array(x, dtype=float),
        "a": a,
        "y": np.array(y_onehot, dtype=float),
        "labels": labels,
        "G": nx.from_scipy_sparse_array(a),
        "pyg_data": pyg,
    }


def load_dataset(dataset_name, root=".", ogb_root="/Users/user/Benchmarking_fuse"):
    """
    Load a node-classification dataset by name.

    Parameters
    ----------
    dataset_name : str
        Case-insensitive. One of:
          - 'cora'
          - 'citeseer'
          - 'pubmed'
          - 'wikics'
          - 'photo' / 'amazon-photo' / 'amazon_photos'
          - 'arxiv' (ogbn-arxiv)
          - 'mag' (paper nodes and paper-cites-paper edges of ogbn-mag)
          - 'products' / 'ogbn-products' / 'product' / 'amazon-products'
    root : str
        Download/cache directory for the PyG datasets (CiteSeer, PubMed, WikiCS, Photo).
    ogb_root : str
        Download/cache directory for the OGB datasets (arxiv, mag, products). The
        default is the machine-specific path this module has always used; override
        it on other machines.

    Returns
    -------
    dict with keys
      'x'        : float numpy array of node features,
      'a'        : scipy CSR adjacency. It is symmetrised and binary for every
                   dataset except Cora, whose adjacency comes straight from Spektral,
      'y'        : float one-hot labels,
      'labels'   : integer labels,
      'G'        : networkx graph built from 'a',
      'pyg_data' : torch_geometric.data.Data(x, edge_index, y).

    Raises
    ------
    ValueError
        If ``dataset_name`` is not recognised.
    """
    name = dataset_name.lower()
    a = None  # built from edge_index below unless the loader already supplies it

    if name == "cora":
        graph = Cora().graphs[0]
        a = graph.a.tocsr() if sp.issparse(graph.a) else sp.csr_matrix(graph.a)
        x = graph.x
        labels = np.argmax(graph.y, axis=1)
        # Spektral ships an adjacency but no edge list; derive one for the PyG object.
        a_coo = a.tocoo()
        edge_index = np.vstack([a_coo.row, a_coo.col])
    elif name in ("citeseer", "pubmed"):
        # Planetoid provides CiteSeer & PubMed
        x, labels, edge_index = _unpack_pyg(Planetoid(root=root, name=dataset_name.capitalize())[0])
    elif name == "wikics":
        x, labels, edge_index = _unpack_pyg(WikiCS(root=root)[0])
    elif name in ("photo", "amazon-photo", "amazon_photos"):
        x, labels, edge_index = _unpack_pyg(Amazon(root=root, name="photo")[0])
    elif name == "arxiv":
        data = _load_ogb_graph("ogbn-arxiv", ogb_root)
        x = data.x.numpy()  # use the given features
        labels = data.y.squeeze().numpy()  # OGB labels are [N, 1]
        edge_index = data.edge_index.numpy()
    elif name == "mag":
        data = _load_ogb_graph("ogbn-mag", ogb_root)
        print(" Detected dict-style MAG dataset (x_dict / y_dict / edge_index_dict).")
        # Keep only the homogeneous "paper" sub-graph with ("paper", "cites", "paper") edges.
        x = data.x_dict['paper'].numpy()
        labels = data.y_dict['paper'].squeeze().numpy()
        edge_index = data.edge_index_dict[('paper', 'cites', 'paper')].numpy()
    elif name in ("products", "ogbn-products", "product", "amazon-products"):
        data = _load_ogb_graph("ogbn-products", ogb_root)
        x = data.x.numpy()
        labels = data.y.squeeze().numpy()
        edge_index = data.edge_index.numpy()
    else:
        raise ValueError(f"Unknown dataset: {dataset_name!r}")

    if a is None:
        # Size by the feature matrix (not edge_index.max() + 1) so isolated nodes are kept.
        a = _edge_index_to_adjacency(edge_index, x.shape[0])
    return _pack_dataset(x, a, labels, edge_index)



# ----------------------------
# Masking utility (allow import from file)
# ----------------------------
def create_label_mask(labels, mask_frac=0.7, seed=None, mask_indices_path=None):
    """
    Hide a subset of node labels for semi-supervised evaluation.

    Parameters
    ----------
    labels : array-like of int, shape (n,)
    mask_frac : float
        Fraction of nodes to MASK (hide). ``round(n * mask_frac)`` nodes are drawn
        without replacement, so the known fraction is ``1 - mask_frac``. Ignored
        when ``mask_indices_path`` is given.
    seed : int or None
        Seed for a private ``RandomState``; None draws from numpy's global RNG.
    mask_indices_path : str or None
        Path to a ``.npy`` file whose entries are the node indices to mask.

    Returns
    -------
    masked_labels : int array
        -1 at masked positions, the original label elsewhere.
    label_mask : bool array
        True where the label is KNOWN (``masked_labels != -1``). Nodes whose
        original label is already -1 therefore also count as unknown.
    labels_to_be_masked : np.ndarray of int
        Indices of the masked nodes.

    Raises
    ------
    ValueError
        If a loaded index lies outside ``[0, n)``, typically because the mask
        file was generated for a different dataset. Such indices used to be
        ignored silently.
    """
    labels = np.asarray(labels)
    n = len(labels)

    if mask_indices_path is not None:
        labels_to_be_masked = np.array(np.load(mask_indices_path), dtype=int).ravel()
        out_of_range = (labels_to_be_masked < 0) | (labels_to_be_masked >= n)
        if out_of_range.any():
            raise ValueError(
                f"{int(out_of_range.sum())} mask indices in {mask_indices_path!r} fall outside "
                f"[0, {n}); was this mask generated for a different dataset?"
            )
    else:
        rng = np.random.RandomState(seed) if seed is not None else np.random
        k = int(round(n * mask_frac))
        labels_to_be_masked = rng.choice(np.arange(n), size=k, replace=False)

    # Vectorised: copy the labels, then blank out the masked positions.
    masked = np.array(labels, dtype=int)
    masked[labels_to_be_masked] = -1
    label_mask = masked != -1
    return masked, label_mask, labels_to_be_masked



# ----------------------------
# Embedding generation functions
# ----------------------------
def deepwalk_embedding(G, k=DEFAULT_EMB_DIM, workers=1, p=1, q=1, seed=None, walk_length=5, num_walks=10):
    """
    DeepWalk embedding: uniform random walks followed by skip-gram training.

    Implemented as node2vec with return parameter p=1 and in-out parameter q=1,
    which makes walk transitions uniform over neighbours.

    Parameters
    ----------
    G : networkx graph
    k : int
        Embedding dimension.
    workers, seed, walk_length, num_walks
        Forwarded unchanged to ``node2vec_embedding``.
    p, q
        Accepted only so the signature mirrors ``node2vec_embedding``. They are
        IGNORED: DeepWalk is by definition the p = q = 1 special case.

    Returns
    -------
    np.ndarray with one row per node, in ``G.nodes()`` order.
    """
    # Delegate instead of duplicating the node2vec code; p and q are pinned to 1.
    return node2vec_embedding(
        G,
        k=k,
        workers=workers,
        p=1,
        q=1,
        seed=seed,
        walk_length=walk_length,
        num_walks=num_walks,
    )


def node2vec_embedding(G, k=DEFAULT_EMB_DIM, workers=1, p=0.5, q=2, seed=None, walk_length=5, num_walks=10):
    """
    node2vec embedding: biased second-order random walks followed by skip-gram training.

    Parameters
    ----------
    G : networkx graph
    k : int
        Embedding dimension.
    workers : int
        Worker count passed to the ``node2vec`` library.
    p, q : float
        Return and in-out bias parameters of the walk.
    seed : int or None
        Seed passed to the ``node2vec`` library.
    walk_length, num_walks : int
        Length of each walk and number of walks started per node.

    Returns
    -------
    np.ndarray with one row per node, in ``G.nodes()`` order.
    """
    walker = Node2Vec(
        G,
        dimensions=k,
        walk_length=walk_length,
        num_walks=num_walks,
        p=p,
        q=q,
        workers=workers,
        seed=seed,
    )
    vectors = walker.fit().wv
    # The trained vocabulary is keyed by str(node), so look every node up by its string form.
    rows = [vectors[str(node)] for node in G.nodes()]
    return np.vstack(rows)


def random_embedding(n_nodes, k=DEFAULT_EMB_DIM, seed=None):
    """
    Return an ``(n_nodes, k)`` matrix of i.i.d. standard-normal entries.

    With a ``seed``, a private ``RandomState`` is used, so the result is
    reproducible and the global RNG is untouched. Without one, numpy's global
    legacy RNG is used.
    """
    if seed is None:
        return np.random.randn(n_nodes, k)
    return np.random.RandomState(seed).randn(n_nodes, k)

def given_embedding(features, k=None, seed=None):
    """
    Use the dataset's own node features as the embedding.

    Parameters
    ----------
    features : array-like, shape (n_nodes, d)
    k : int or None
        Target dimensionality.
        - None (the default): the features are returned unchanged, cast to float.
        - d > k: the features are projected to k dimensions with TruncatedSVD.
        - d < k: the features are right-padded with zero columns.
    seed : int or None
        ``random_state`` for TruncatedSVD; only used when projecting.

    Returns
    -------
    np.ndarray of float, shape (n_nodes, d) when ``k`` is None, else (n_nodes, k).
    """
    X = np.array(features, dtype=float)
    if k is None or X.shape[1] == k:
        return X
    if X.shape[1] > k:
        return TruncatedSVD(n_components=k, random_state=seed).fit_transform(X)
    padding = np.zeros((X.shape[0], k - X.shape[1]), dtype=X.dtype)
    return np.hstack([X, padding])

# VGAE & DGI (PyG) — both are trained on random Gaussian input features of width k
def vgae_embedding(pyg_data, k=DEFAULT_EMB_DIM, epochs=200, device='cpu', seed=None):
    """
    Variational Graph Auto-Encoder embedding (PyG ``VGAE`` with a two-layer GCN encoder).

    The input node features are i.i.d. Gaussian noise of width ``k``. Only
    ``pyg_data.num_nodes`` and ``pyg_data.edge_index`` are read.

    Parameters
    ----------
    pyg_data : torch_geometric.data.Data
    k : int
        Embedding dimension, and also the width of the random input features.
    epochs : int
        Number of full-batch Adam steps (lr=0.01).
    device : str or torch.device
    seed : int or None
        If given, ``torch.manual_seed(seed)`` is called first so the random
        features and weight initialisation are reproducible. This reseeds
        torch's global RNG.

    Returns
    -------
    np.ndarray of shape (num_nodes, k)
        Encoder output computed in eval mode. In eval mode, PyG's VGAE returns
        the posterior mean rather than a sample.
    """
    if seed is not None:
        torch.manual_seed(seed)
    device = torch.device(device)
    num_nodes = pyg_data.num_nodes
    x = torch.randn(num_nodes, k, device=device)
    edge_index = pyg_data.edge_index.to(device)

    class Encoder(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv1 = PyG_GCNConv(in_channels, 2 * out_channels)
            self.conv_mu = PyG_GCNConv(2 * out_channels, out_channels)
            self.conv_logstd = PyG_GCNConv(2 * out_channels, out_channels)

        def forward(self, x, edge_index):
            x = torch.relu(self.conv1(x, edge_index))
            mu = self.conv_mu(x, edge_index)
            logstd = self.conv_logstd(x, edge_index)
            return mu, logstd

    # in_channels must equal the feature width (k). Passing num_nodes, as for
    # one-hot identity features, makes the first GCNConv fail with a shape mismatch.
    model = PyG_VGAE(Encoder(x.size(1), k)).to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.01)
    model.train()
    for _ in range(epochs):
        opt.zero_grad()
        z = model.encode(x, edge_index)
        # KL term scaled by 1/N, as in the standard PyG VGAE recipe.
        loss = model.recon_loss(z, edge_index) + (1.0 / num_nodes) * model.kl_loss()
        loss.backward()
        opt.step()
    model.eval()
    with torch.no_grad():
        z = model.encode(x, edge_index)
    return z.detach().cpu().numpy()

def dgi_embedding(pyg_data, k=150, epochs=200, device='cpu', seed=None):
    """
    Deep Graph Infomax embedding (PyG ``DeepGraphInfomax`` with a two-layer GCN encoder).

    The input node features are i.i.d. Gaussian noise of width ``k``. Only
    ``pyg_data.num_nodes`` and ``pyg_data.edge_index`` are read.

    Parameters
    ----------
    pyg_data : torch_geometric.data.Data
    k : int
        Embedding dimension, and also the width of the random input features.
    epochs : int
        Number of full-batch Adam steps (lr=0.01).
    device : str or torch.device
    seed : int or None
        If given, ``torch.manual_seed(seed)`` is called first so the random
        features, weight initialisation and corruption permutations are
        reproducible. This reseeds torch's global RNG.

    Returns
    -------
    np.ndarray of shape (num_nodes, k)
        Encoder output for the uncorrupted graph, computed in eval mode.
    """
    if seed is not None:
        torch.manual_seed(seed)
    device = torch.device(device)
    num_nodes = pyg_data.num_nodes
    x = torch.randn(num_nodes, k, device=device)  # random input features
    edge_index = pyg_data.edge_index.to(device)

    class GCNEncoder(torch.nn.Module):
        def __init__(self, in_channels, out_channels):
            super().__init__()
            self.conv1 = PyG_GCNConv(in_channels, 2 * out_channels)
            self.conv2 = PyG_GCNConv(2 * out_channels, out_channels)

        def forward(self, x, edge_index):
            x = torch.relu(self.conv1(x, edge_index))
            return self.conv2(x, edge_index)

    def summary(z, *args, **kwargs):
        # NOTE(review): the DGI paper's readout is sigmoid(mean(z)); the plain mean is
        # kept here to preserve previously reported results.
        return torch.mean(z, dim=0)

    def corruption(x, edge_index):
        # Negative samples: shuffle feature rows, keep the graph. The permutation is
        # created on x's device.
        return x[torch.randperm(x.size(0), device=x.device)], edge_index

    model = DeepGraphInfomax(
        hidden_channels=k,
        encoder=GCNEncoder(in_channels=k, out_channels=k),
        summary=summary,
        corruption=corruption,
    ).to(device)

    opt = torch.optim.Adam(model.parameters(), lr=0.01)
    model.train()
    for _ in range(epochs):
        opt.zero_grad()
        pos_z, neg_z, graph_summary = model(x, edge_index)
        loss = model.loss(pos_z, neg_z, graph_summary)
        loss.backward()
        opt.step()

    model.eval()
    with torch.no_grad():
        pos_z, _, _ = model(x, edge_index)

    return pos_z.detach().cpu().numpy()



def fuse_unsupervised(G, k=150, eta=0.01, lambda_unsupervised=1e9, iterations=200, seed=None):
    """
    Unsupervised FUSE embedding: gradient ascent on a linearised modularity objective.

    Each iteration:
      1. Computes the modularity gradient ``(A S - d (d^T S) / 2m) / 2m``, where
         A is the adjacency, d the degree vector and m the edge count. The
         modularity matrix is applied without ever being materialised.
      2. Takes a step ``S += eta * lambda_unsupervised * grad``.
      3. Re-orthonormalises S with a QR decomposition.

    ``lambda_unsupervised`` is applied exactly once per step. The previous code
    multiplied the gradient by it twice, giving an effective weight of
    lambda ** 2.

    Parameters
    ----------
    G : networkx graph
        Row i of the result corresponds to the i-th node in ``G.nodes()``.
    k : int
        Embedding dimension.
    eta : float
        Step size.
    lambda_unsupervised : float
        Weight of the modularity gradient.
    iterations : int
        Number of ascent steps.
    seed : int or None
        Seed for a private ``RandomState`` used for the initial matrix. It
        produces the same initial values as seeding numpy's global RNG, but
        leaves the global RNG untouched.

    Returns
    -------
    np.ndarray of shape (n, min(n, k)) with orthonormal columns.

    Raises
    ------
    ValueError
        If G has no edges (the objective divides by 2m).
    """
    rng = np.random.RandomState(seed) if seed is not None else np.random

    A = csr_matrix(nx.to_scipy_sparse_array(G, format='csr'))
    degrees = np.asarray(A.sum(axis=1)).ravel()
    m = G.number_of_edges()
    if m == 0:
        raise ValueError("fuse_unsupervised requires a graph with at least one edge")
    two_m = 2.0 * m
    n = A.shape[0]

    S, _ = np.linalg.qr(rng.randn(n, k))
    step = eta * lambda_unsupervised

    for _ in tqdm(range(iterations), desc="Gradient Ascent with Linear Modularity"):
        neighbor_agg = A @ S  # sparse aggregation of neighbour embeddings
        # Rank-one null-model term d (d^T S) / 2m, written as an outer product.
        global_correction = np.outer(degrees / two_m, S.sum(axis=0))
        grad_modularity = (neighbor_agg - global_correction) / two_m
        S, _ = np.linalg.qr(S + step * grad_modularity)

    return S


class CPU_Sparse_GCNConv(GCNConv):
    """
    Spektral ``GCNConv`` whose forward computation is placed on ``/CPU:0``.

    ``training`` is accepted for Keras compatibility but unused. The incoming
    ``mask`` is not forwarded: the parent layer is always called with ``mask=None``.
    """

    def call(self, inputs, training=None, mask=None):
        cpu_scope = tf.device("/CPU:0")
        with cpu_scope:
            outputs = super().call(inputs, mask=None)
        return outputs

class CPU_Sparse_GATConv(GATConv):
    """
    Spektral ``GATConv`` whose forward computation is placed on ``/CPU:0``.

    ``training`` is accepted for Keras compatibility but not passed on explicitly.
    The incoming ``mask`` is not forwarded: the parent layer is always called with
    ``mask=None``.
    """

    def call(self, inputs, training=None, mask=None):
        cpu_scope = tf.device("/CPU:0")
        with cpu_scope:
            outputs = super().call(inputs, mask=None)
        return outputs

class CPU_Sparse_SAGEConv(GraphSageConv):
    """
    Spektral ``GraphSageConv`` whose forward computation is placed on ``/CPU:0``.

    ``training`` and ``mask`` are accepted for Keras compatibility. The parent
    layer is called with ``inputs`` only.
    """

    def call(self, inputs, training=None, mask=None):
        cpu_scope = tf.device("/CPU:0")
        with cpu_scope:
            outputs = super().call(inputs)
        return outputs



class GCN(tf.keras.Model):
    """
    Two-layer GCN node classifier built from CPU-pinned Spektral ``GCNConv`` layers.

    Layer 1 has 16 channels with ReLU; layer 2 has ``n_labels`` channels with softmax.
    Both kernels share one Glorot-uniform initializer seeded with ``seed``.
    """

    def __init__(self, n_labels, seed=42):
        super().__init__()
        glorot = tf.keras.initializers.GlorotUniform(seed=seed)
        self.conv1 = CPU_Sparse_GCNConv(16, activation='relu', kernel_initializer=glorot)
        self.conv2 = CPU_Sparse_GCNConv(n_labels, activation='softmax', kernel_initializer=glorot)

    def call(self, inputs, training=False):
        """Return ``(class_probabilities, hidden_activations)`` for ``inputs = (x, a)``."""
        features, adjacency = inputs
        hidden = self.conv1([features, adjacency])
        return self.conv2([hidden, adjacency]), hidden


class GAT(tf.keras.Model):
    """
    Two-layer graph attention classifier built from CPU-pinned Spektral ``GATConv`` layers.

    Layer 1: 16 channels per head, ``num_heads`` heads concatenated, ELU.
    Layer 2: ``n_labels`` channels, a single head, softmax.
    Both kernels share one Glorot-uniform initializer seeded with ``seed``.
    """

    def __init__(self, n_labels, num_heads=8, seed=42):
        super().__init__()
        glorot = tf.keras.initializers.GlorotUniform(seed=seed)
        hidden_cfg = dict(attn_heads=num_heads, concat_heads=True, activation='elu')
        output_cfg = dict(attn_heads=1, concat_heads=False, activation='softmax')
        self.conv1 = CPU_Sparse_GATConv(16, kernel_initializer=glorot, **hidden_cfg)
        self.conv2 = CPU_Sparse_GATConv(n_labels, kernel_initializer=glorot, **output_cfg)

    def call(self, inputs, training=False):
        """Return ``(class_probabilities, hidden_activations)`` for ``inputs = (x, a)``."""
        features, adjacency = inputs
        hidden = self.conv1([features, adjacency])
        return self.conv2([hidden, adjacency]), hidden

class GraphSAGE(tf.keras.Model):
    """
    Two-layer GraphSAGE node classifier built from CPU-pinned Spektral ``GraphSageConv`` layers.

    Parameters
    ----------
    n_labels : int
        Number of output classes (softmax output layer).
    hidden_dim : int
        Width of the ReLU hidden layer.
    aggregator : str
        Neighbourhood aggregation, 'mean' by default. It is forwarded to
        GraphSageConv as ``aggregate``, the keyword that layer defines. Earlier
        versions passed ``aggregator=``, which GraphSageConv does not define, so
        non-default values never took effect.
    seed : int
        Seed of the Glorot-uniform initializer shared by both kernels.
    """

    def __init__(self, n_labels, hidden_dim=16, aggregator='mean', seed=42):
        super().__init__()
        initializer = tf.keras.initializers.GlorotUniform(seed=seed)

        self.conv1 = CPU_Sparse_SAGEConv(
            hidden_dim, activation='relu', aggregate=aggregator,
            kernel_initializer=initializer
        )
        self.conv2 = CPU_Sparse_SAGEConv(
            n_labels, activation='softmax', aggregate=aggregator,
            kernel_initializer=initializer
        )

    def call(self, inputs, training=False):
        """Return ``(class_probabilities, hidden_activations)`` for ``inputs = (x, a)``."""
        x, a = inputs
        h1 = self.conv1([x, a])
        out = self.conv2([h1, a])
        return out, h1




# ----------------------------
# Training / evaluation helpers
# ----------------------------
def sparse_to_tf_sparse(A_csr):
    """
    Convert a SciPy sparse matrix into a ``tf.sparse.SparseTensor`` in canonical order.

    Values are cast to float32. ``tf.sparse.reorder`` sorts the indices into
    canonical row-major order before the tensor is returned.
    """
    coo = A_csr.tocoo()
    coords = np.column_stack((coo.row, coo.col))
    tensor = tf.sparse.SparseTensor(
        indices=coords,
        values=coo.data.astype(np.float32),
        dense_shape=coo.shape,
    )
    return tf.sparse.reorder(tensor)

def evaluate_preds(true_int_labels, pred_int_labels):
    """
    Score integer predictions against integer ground truth.

    Returns a dict with 'accuracy' (float), 'f1_score' (macro-averaged F1,
    float) and 'confusion_matrix' (as returned by sklearn).
    """
    y_true, y_pred = true_int_labels, pred_int_labels
    return {
        "accuracy": float(accuracy_score(y_true, y_pred)),
        "f1_score": float(f1_score(y_true, y_pred, average='macro')),
        "confusion_matrix": confusion_matrix(y_true, y_pred),
    }



def _gcn_normalize_adjacency(adjacency_csr):
    """
    Return the GCN renormalised adjacency ``D^-1/2 (A + I) D^-1/2`` as float32 CSR.

    Spektral's GCNConv documentation expects this pre-processed matrix rather
    than the raw adjacency (cf. ``GCNConv.preprocess`` / ``gcn_filter``). D is
    the degree matrix of ``A + I``. For a non-negative A every degree is >= 1,
    so no division by zero can occur.
    """
    n = adjacency_csr.shape[0]
    a_hat = sp.csr_matrix(adjacency_csr, dtype=np.float32) + sp.eye(n, dtype=np.float32, format="csr")
    degrees = np.asarray(a_hat.sum(axis=1)).ravel()
    d_inv_sqrt = sp.diags(1.0 / np.sqrt(degrees))
    return sp.csr_matrix(d_inv_sqrt @ a_hat @ d_inv_sqrt, dtype=np.float32)


def train_and_evaluate_classifier(embedding_matrix, adjacency_csr, labels_onehot, labels_int, label_mask,
                                  classifier_name='gcn', epochs=200, seed=42, verbose=False,
                                  gcn_normalize=True):
    """
    Train a GNN classifier on the labelled sub-graph and evaluate it on the masked nodes.

    Training is inductive: the model only sees the sub-graph induced by the
    nodes with known labels. Prediction then runs on the full graph.

    Parameters
    ----------
    embedding_matrix : array-like, shape (n_nodes, d)
        Node features for the classifier.
    adjacency_csr : scipy sparse matrix, shape (n_nodes, n_nodes)
    labels_onehot : np.ndarray, shape (n_nodes, n_classes)
    labels_int : array-like of int, shape (n_nodes,)
    label_mask : array-like of bool
        True for nodes whose labels are KNOWN (used in training). Converted to
        bool, so a 0/1 integer mask also works.
    classifier_name : str
        'gcn', 'gat', or 'graphsage' / 'sage' / 'graph_sage' (case-insensitive).
    epochs : int
        Number of full-batch Adam steps (lr=1e-2).
    seed : int
        Seeds TensorFlow, numpy and ``random`` globally, and the model initializers.
    verbose : bool
        Print split sizes and the loss every 50 epochs.
    gcn_normalize : bool
        For 'gcn', feed ``D^-1/2 (A + I) D^-1/2`` instead of the raw adjacency.
        Set to False to reproduce results obtained with the raw adjacency.
        GAT and GraphSAGE always receive the raw adjacency.

    Returns
    -------
    dict with
      - 'accuracy', 'f1_score', 'confusion_matrix' (on the masked nodes),
      - 'training_time_seconds',
      - 'classifier',
      - 'predictions_all' (int predictions for every node),
      - 'masked_indices'.

    Raises
    ------
    ValueError
        If ``classifier_name`` is not recognised.
    """
    tf.random.set_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    clf_key = classifier_name.lower()
    label_mask = np.asarray(label_mask, dtype=bool)
    labels_int = np.asarray(labels_int)
    num_classes = labels_onehot.shape[1]

    train_idx = np.where(label_mask)[0]
    masked_idx = np.where(~label_mask)[0]

    # Training sub-graph: features, labels and adjacency restricted to known nodes.
    X_train = tf.convert_to_tensor(np.asarray(embedding_matrix[train_idx], dtype=np.float32))
    y_train = labels_onehot[train_idx]
    A_train = adjacency_csr[train_idx, :][:, train_idx]
    A_full = adjacency_csr
    if clf_key == 'gcn' and gcn_normalize:
        # Each graph is normalised with its own degrees.
        A_train = _gcn_normalize_adjacency(A_train)
        A_full = _gcn_normalize_adjacency(A_full)
    A_train_sp = sparse_to_tf_sparse(A_train)

    if clf_key == 'gcn':
        model = GCN(num_classes, seed=seed)
    elif clf_key == 'gat':
        model = GAT(num_classes, seed=seed)
    elif clf_key in ('graphsage', 'sage', 'graph_sage'):
        model = GraphSAGE(num_classes, seed=seed)
    else:
        raise ValueError("Unknown classifier: " + classifier_name)

    optimizer = tf.keras.optimizers.Adam(learning_rate=1e-2)
    loss_fn = tf.keras.losses.CategoricalCrossentropy()

    if verbose:
        num_train = len(train_idx)
        num_test = len(labels_int) - num_train
        print(f"[Classifier={classifier_name}] Train={num_train}, Test={num_test}")

    t0 = time.time()
    for epoch in range(epochs):
        with tf.GradientTape() as tape:
            preds, _ = model([X_train, A_train_sp], training=True)  # (n_train_nodes, num_classes)
            loss = loss_fn(y_train, preds)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        if verbose and (epoch % 50 == 0):
            print(f"[{classifier_name}] epoch {epoch}, loss={float(loss):.4f}")
    training_time = time.time() - t0

    # Predict every node using the full graph.
    X_full = tf.convert_to_tensor(np.asarray(embedding_matrix, dtype=np.float32))
    A_full_sp = sparse_to_tf_sparse(A_full)
    preds_all, _ = model([X_full, A_full_sp], training=False)
    pred_int = tf.argmax(preds_all, axis=1).numpy()

    results = evaluate_preds(labels_int[masked_idx], pred_int[masked_idx])
    results.update({
        "training_time_seconds": float(training_time),
        "classifier": classifier_name,
        "predictions_all": pred_int,
        "masked_indices": masked_idx
    })
    return results


# ----------------------------
# Orchestrator: run single dataset + seed + mask_frac
# ----------------------------
def _compute_embedding(emb_name, n, X, G, pyg, emb_dim, seed, vgae_epochs, dgi_epochs,
                       fuse_iterations, device):
    """Run the embedding generator selected by ``emb_name`` (case-insensitive) and return its matrix."""
    key = emb_name.lower()
    if key == 'random':
        return random_embedding(n, k=emb_dim, seed=seed)
    if key == 'given':
        return given_embedding(X)
    if key == 'deepwalk':
        return deepwalk_embedding(G, k=emb_dim, seed=seed)
    if key == 'node2vec':
        return node2vec_embedding(G, k=emb_dim, seed=seed)
    if key == 'vgae':
        return vgae_embedding(pyg, k=emb_dim, epochs=vgae_epochs, device=device)
    if key == 'dgi':
        return dgi_embedding(pyg, k=emb_dim, epochs=dgi_epochs, device=device)
    if key in ('fuse', 'modularity'):
        return fuse_unsupervised(G, k=emb_dim, iterations=fuse_iterations, seed=seed)
    raise ValueError("Unknown embedding name: " + emb_name)


def run_one_experiment(dataset_name, seed, mask_frac, emb_dim=DEFAULT_EMB_DIM,
                       embedding_methods=None, classifiers=None,
                       mask_indices_path=None, vgae_epochs=200, dgi_epochs=200,
                       fuse_iterations=200, force_device='cpu',
                       save_dir="./benchmark_outputs", verbose=False,
                       classifier_epochs=200):
    """
    Run every (embedding method, classifier) pair for one dataset / seed / mask fraction.

    Embeddings are cached as pickles under
    ``save_dir/<dataset>/<masked>-<known>/``. A cached file is loaded instead of
    recomputed; its embedding time is then reported as 0.0 and
    ``embedding_cached`` is True in the per-run records.

    Parameters
    ----------
    dataset_name : str
        Passed to ``load_dataset``.
    seed : int
        Used for masking, embeddings and classifiers.
    mask_frac : float
        Fraction of labels hidden from the classifiers.
    emb_dim : int
        Embedding dimension. Non-default values get their own cache files.
    embedding_methods, classifiers : list of str or None
        None selects every method / classifier.
    mask_indices_path : str or None
        Optional ``.npy`` file with the node indices to mask.
    vgae_epochs, dgi_epochs, fuse_iterations : int
        Training lengths for the respective embedding methods.
    force_device : str
        Torch device for VGAE / DGI.
    save_dir : str
        Root directory for the embedding cache.
    verbose : bool
        Print progress, classifier logs and each per-run record.
    classifier_epochs : int
        Training epochs for every classifier.

    Returns
    -------
    run_results : list of dict
        One record per (embedding, classifier) pair.
    embedding_times : list of (embedding name, seconds) tuples
    embeddings : dict
        Maps embedding name to embedding matrix.
    """
    import pickle

    if embedding_methods is None:
        embedding_methods = ['random', 'given', 'deepwalk', 'node2vec', 'vgae', 'dgi', 'fuse']
    if classifiers is None:
        classifiers = ['gcn', 'gat', 'graphsage']

    ds = load_dataset(dataset_name)
    X, A = ds['x'], ds['a']
    y_onehot, labels_int = ds['y'], ds['labels']
    G, pyg = ds['G'], ds['pyg_data']
    n = X.shape[0]

    _, label_mask, _ = create_label_mask(
        labels_int, mask_frac, seed=seed, mask_indices_path=mask_indices_path
    )
    # Count from the boolean mask so duplicate indices in a mask file are not double-counted.
    num_unmasked = int(np.count_nonzero(label_mask))
    num_masked = n - num_unmasked
    if verbose:
        print(f"[{dataset_name}][seed={seed}][mf={mask_frac}] "
              f"Masked={num_masked}, Unmasked={num_unmasked}")

    masked_pct = int(mask_frac * 100)
    known_pct = 100 - masked_pct
    dsdir = os.path.join(save_dir, dataset_name, f"{masked_pct}-{known_pct}")
    os.makedirs(dsdir, exist_ok=True)

    # A non-default width gets its own cache file so changing emb_dim never silently
    # reuses an embedding of the wrong size. Default-width names are unchanged, so
    # existing caches stay valid.
    dim_tag = "" if emb_dim == DEFAULT_EMB_DIM else f"_k{emb_dim}"

    embedding_times = []
    embeddings = {}
    run_results = []

    for emb_name in embedding_methods:
        if verbose:
            print(f"[{dataset_name}][seed={seed}][mask_frac={mask_frac}] Running {emb_name} …")

        filename = f"{emb_name.lower()}_embedding_{masked_pct}_{known_pct}_{seed}{dim_tag}.pkl"
        filepath = os.path.join(dsdir, filename)

        cached = os.path.exists(filepath)
        if cached:
            if verbose:
                print(f"→ Loading cached embedding: {filepath}")
            # NOTE: pickle.load can execute arbitrary code; only point save_dir at trusted files.
            with open(filepath, "rb") as f:
                E = pickle.load(f)
            t_elapsed = 0.0
        else:
            if verbose:
                print(f"→ Computing embedding: {emb_name}")
            tstart = time.time()
            E = _compute_embedding(
                emb_name, n, X, G, pyg, emb_dim, seed,
                vgae_epochs, dgi_epochs, fuse_iterations, force_device,
            )
            t_elapsed = time.time() - tstart
            with open(filepath, "wb") as f:
                pickle.dump(E, f)

        embedding_times.append((emb_name, t_elapsed))
        embeddings[emb_name] = E

        for clf in classifiers:
            clf_start = time.time()
            res = train_and_evaluate_classifier(
                E, A, y_onehot, labels_int, label_mask,
                classifier_name=clf, epochs=classifier_epochs, seed=seed, verbose=verbose
            )
            clf_elapsed = time.time() - clf_start

            record = {
                "dataset": dataset_name,
                "seed": seed,
                "mask_frac": mask_frac,
                "embedding": emb_name,
                "classifier": clf,
                "embedding_time_seconds": float(t_elapsed),
                "classifier_time_seconds": float(clf_elapsed),
                "train_time_seconds": float(res.get("training_time_seconds", np.nan)),
                "accuracy": float(res["accuracy"]),
                "f1_score": float(res["f1_score"]),
                "embedding_cached": cached,
            }
            run_results.append(record)
            if verbose:
                print(record)

    return run_results, embedding_times, embeddings



# ----------------------------
# Helper to build mask file path
# ----------------------------
def get_mask_file_path(masks_root, dataset_name, seed, mask_frac):
    """
    Build the expected path of a pre-generated mask file.

    The split is derived from the rounded percentage ``round(mask_frac * 100)``
    rather than exact float equality, so any mask fraction is supported and
    values such as ``1 - 0.3`` still resolve to the 70/30 split.

    Folder structure:
      masks_root/
        AmazonPhotos/
          70_30/   <- 70% masked, 30% known
            AmazonPhotos_70_30_masked_indices_seed{seed}.npy
          30_70/   <- 30% masked, 70% known
        Cora/...
        PubMed/...
        WikiCS/...
        CiteSeer/...
        Arxiv/...
        MAG/...

    Returns
    -------
    str or None
        The path if the file exists. None if the dataset has no folder mapping
        or the file is absent.
    """
    dataset_map = {
        "cora": "Cora",
        "citeseer": "CiteSeer",
        "pubmed": "PubMed",
        "wikics": "WikiCS",
        "photo": "AmazonPhotos",
        "amazon-photo": "AmazonPhotos",
        "amazon_photos": "AmazonPhotos",
        "arxiv": "Arxiv",
        "mag": "MAG"
    }
    folder_name = dataset_map.get(dataset_name.lower())
    if folder_name is None:
        return None  # no mapping -> caller falls back to a random mask

    masked_pct = int(round(mask_frac * 100))
    split = f"{masked_pct}_{100 - masked_pct}"  # e.g. "70_30": 70% masked, 30% known
    fname = f"{folder_name}_{split}_masked_indices_seed{seed}.npy"

    mask_path = os.path.join(masks_root, folder_name, split, fname)
    return mask_path if os.path.exists(mask_path) else None


# ----------------------------
# High-level run_benchmark driver
# ----------------------------
def _format_pm(frame, mean_col, std_col):
    """Render ``mean_col`` and ``std_col`` of every row as a 'mean ± std' string with 4 decimals."""
    return frame.apply(lambda r: f"{r[mean_col]:.4f} ± {r[std_col]:.4f}", axis=1)


def run_benchmark(datasets, seeds, mask_fracs=(0.7, 0.3), emb_dim=DEFAULT_EMB_DIM,
                  embedding_methods=None, classifiers=None,
                  vgae_epochs=200, dgi_epochs=200, fuse_iterations=200,
                  save_dir="./benchmark_outputs", device='cpu',
                  masks_root="./mag_masks", verbose=False,
                  output_suffix="mag_fuse_unsup_walk_length_5"):
    """
    Run the full benchmark grid and write per-run and aggregated CSV files.

    For every (dataset, mask fraction) pair and every seed, ``run_one_experiment``
    is called. It uses the pre-generated mask from ``masks_root`` when one exists
    and a random mask otherwise.

    Output files:
      - After each (dataset, mask fraction) pair, ``per_run_results.csv`` and
        ``embedding_times.csv`` are written under
        ``save_dir/<dataset>/<masked>-<known>/``.
      - At the end, three aggregate CSVs whose names end in
        ``_{output_suffix}.csv`` are written to ``save_dir``.

    Parameters
    ----------
    datasets : iterable of str
    seeds : iterable of int
    mask_fracs : iterable of float
        Fractions of labels to hide. The default is an immutable tuple.
    emb_dim, embedding_methods, classifiers, vgae_epochs, dgi_epochs, fuse_iterations
        Forwarded to ``run_one_experiment``.
    save_dir : str
        Root output directory.
    device : str
        Torch device for VGAE / DGI.
    masks_root : str
        Root of the pre-generated mask files.
    verbose : bool
        Print progress; this also disables the task-level progress bar.
    output_suffix : str
        Suffix of the three aggregate CSV file names. The default reproduces
        the file names this function has always written.

    Returns
    -------
    dict of DataFrames with keys 'per_run', 'avg_by_model_and_classifier' and
    'avg_embedding_times'.

    Raises
    ------
    ValueError
        If the grid produced no runs (e.g. empty ``datasets``, ``seeds`` or ``mask_fracs``).
    """
    import pandas as pd

    all_results = []
    all_embedding_times = []

    tasks = [(ds, mf) for ds in datasets for mf in mask_fracs]

    for ds, mf in tqdm(tasks, desc="Benchmark tasks", disable=verbose):
        partial_results = []
        partial_times = []
        for seed in seeds:
            mask_path = get_mask_file_path(masks_root, ds, seed, mf)
            if verbose:
                if mask_path:
                    print(f"Using custom mask: {mask_path}")
                else:
                    print(f"No custom mask for [{ds} seed={seed} mf={mf}] → random mask.")

            rr, et, _ = run_one_experiment(
                ds, seed, mf, emb_dim=emb_dim,
                embedding_methods=embedding_methods,
                classifiers=classifiers,
                vgae_epochs=vgae_epochs,
                dgi_epochs=dgi_epochs,
                fuse_iterations=fuse_iterations,
                force_device=device,
                save_dir=save_dir,
                verbose=verbose,
                mask_indices_path=mask_path
            )
            partial_results.extend(rr)
            partial_times.extend(
                {"dataset": ds, "seed": seed, "mask_frac": mf,
                 "embedding": emb_name, "embedding_time_seconds": float(seconds)}
                for emb_name, seconds in et
            )

        masked_pct = int(mf * 100)
        known_pct = 100 - masked_pct
        dsdir = os.path.join(save_dir, ds, f"{masked_pct}-{known_pct}")
        os.makedirs(dsdir, exist_ok=True)

        pd.DataFrame(partial_results).to_csv(os.path.join(dsdir, "per_run_results.csv"), index=False)
        pd.DataFrame(partial_times).to_csv(os.path.join(dsdir, "embedding_times.csv"), index=False)

        all_results.extend(partial_results)
        all_embedding_times.extend(partial_times)

    if not all_results:
        raise ValueError("run_benchmark produced no results; check that datasets, seeds, "
                         "mask_fracs, embedding_methods and classifiers are non-empty.")

    # Final aggregation across seeds
    df_results = pd.DataFrame(all_results)
    df_embtimes = pd.DataFrame(all_embedding_times)

    avg_by_model = df_results.groupby(
        ["dataset", "mask_frac", "embedding", "classifier"]
    ).agg(
        avg_accuracy=("accuracy", "mean"),
        std_accuracy=("accuracy", "std"),
        avg_f1=("f1_score", "mean"),
        std_f1=("f1_score", "std"),
        avg_train_time=("train_time_seconds", "mean"),
        std_train_time=("train_time_seconds", "std"),
        avg_classifier_time=("classifier_time_seconds", "mean"),
        std_classifier_time=("classifier_time_seconds", "std"),
        n_runs=("accuracy", "count")
    ).reset_index()

    avg_by_model["accuracy_pm"] = _format_pm(avg_by_model, "avg_accuracy", "std_accuracy")
    avg_by_model["f1_pm"] = _format_pm(avg_by_model, "avg_f1", "std_f1")
    avg_by_model["classifier_time_pm"] = _format_pm(
        avg_by_model, "avg_classifier_time", "std_classifier_time"
    )

    avg_embtime = df_embtimes.groupby(
        ["dataset", "mask_frac", "embedding"]
    ).agg(
        avg_embedding_time=("embedding_time_seconds", "mean"),
        std_embedding_time=("embedding_time_seconds", "std"),
        n_runs=("embedding_time_seconds", "count")
    ).reset_index()
    avg_embtime["embedding_time_pm"] = _format_pm(
        avg_embtime, "avg_embedding_time", "std_embedding_time"
    )

    os.makedirs(save_dir, exist_ok=True)
    df_results.to_csv(os.path.join(save_dir, f"per_run_results_all_{output_suffix}.csv"), index=False)
    avg_by_model.to_csv(os.path.join(save_dir, f"avg_by_model_and_classifier_{output_suffix}.csv"), index=False)
    avg_embtime.to_csv(os.path.join(save_dir, f"avg_embedding_times_{output_suffix}.csv"), index=False)

    return {
        "per_run": df_results,
        "avg_by_model_and_classifier": avg_by_model,
        "avg_embedding_times": avg_embtime
    }
