from torch_geometric.nn import GINConv, GCNConv, TAGConv, GATConv, GATv2Conv
import torch.nn as nn
import torch.nn.functional as F
import torch
import copy
import numpy as np
import os
import torch_geometric.transforms as T
import pickle
import scipy.sparse as sp
from torch_geometric.utils import to_dense_adj

from sklearn.feature_selection import mutual_info_regression
from sklearn.cluster import AgglomerativeClustering
# from torch_geometric.datasets import Planetoid, Amazon, Planetoid, Amazon, WebKB, WikipediaNetwork, Actor, Coauthor, WikiCS
from torch_geometric.datasets import (
    Planetoid, Amazon, WebKB, WikipediaNetwork, Actor,
    Coauthor, WikiCS, HeterophilousGraphDataset
)

class GCNNodeClassifier(nn.Module):
    """
    Stack of GCNConv layers (each followed by ReLU) topped by a linear read-out.

    Layer widths are ``in_channels -> *hidden_channels_list -> out_channels``:
    every hidden width gets one GCNConv, and the final ``nn.Linear`` maps the
    last hidden width (or ``in_channels`` when the list is empty) to
    ``out_channels``.
    """

    def __init__(self, in_channels, hidden_channels_list, out_channels):
        super().__init__()
        dims = [in_channels] + hidden_channels_list + [out_channels]
        # Consecutive (d_in, d_out) pairs for the graph layers; the last
        # transition (-> out_channels) is handled by the linear head.
        self.convs = nn.ModuleList(
            GCNConv(d_in, d_out, cached=False, normalize=True)
            for d_in, d_out in zip(dims[:-2], dims[1:-1])
        )
        self.lin = nn.Linear(dims[-2], dims[-1])

    def forward(self, x, edge_index):
        h = x
        for conv in self.convs:
            h = F.relu(conv(h, edge_index))
        return self.lin(h)
    
class GCNNodeClassifier_Het(nn.Module):
    """
    Same layout as ``GCNNodeClassifier`` but built from ``TAGConv`` layers
    (K-hop polynomial graph filters, using TAGConv's default K), each followed
    by ReLU, and a final linear read-out.

    Layer widths are ``in_channels -> *hidden_channels_list -> out_channels``.
    """

    def __init__(self, in_channels, hidden_channels_list, out_channels):
        super().__init__()
        dims = [in_channels] + hidden_channels_list + [out_channels]
        self.convs = nn.ModuleList()
        for i in range(len(dims) - 2):
            # TAGConv has no `cached` argument: unknown kwargs are forwarded to
            # MessagePassing.__init__, which raises TypeError on `cached`.
            self.convs.append(TAGConv(dims[i], dims[i + 1], normalize=True))
        self.lin = nn.Linear(dims[-2], dims[-1])

    def forward(self, x, edge_index):
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
        x = self.lin(x)
        return x

class MLPNodeClassifier(nn.Module):
    """
    Graph-agnostic baseline: Linear+ReLU for every hidden width, then a final
    Linear to ``out_channels``. ``edge_index`` is accepted for API parity with
    the GNN classifiers and ignored.
    """

    def __init__(self, in_channels, hidden_channels_list, out_channels):
        super().__init__()
        widths = [in_channels] + hidden_channels_list
        modules = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(d_in, d_out), nn.ReLU()))
        # Output projection from the last hidden width (or the input width).
        modules.append(nn.Linear(widths[-1], out_channels))
        self.net = nn.Sequential(*modules)

    def forward(self, x, edge_index=None):
        return self.net(x)


def set_seed_all(seed=0):
    """
    Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA
    devices) with ``seed``, and switch cuDNN to deterministic, non-benchmark mode.
    """
    import random
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def permute_features(data, seed=0):
    """
    Return a clone of ``data`` whose feature rows are shuffled across nodes
    (the whole row of ``x`` moves together), breaking the node/feature
    alignment while keeping the graph untouched.

    The permutation comes from a ``torch.Generator`` seeded with ``seed`` on
    ``data.x``'s device, so it is reproducible.
    """
    dev = data.x.device
    rng = torch.Generator(device=dev)
    rng.manual_seed(seed)
    row_order = torch.randperm(data.num_nodes, generator=rng, device=dev)

    shuffled = data.clone()
    shuffled.x = data.x.index_select(0, row_order)
    return shuffled



def permute_graph(edge_index, num_nodes=None, seed=0, make_undirected=True):
    """
    Replace the graph with an Erdős–Rényi (G(n, M)) graph that has the same
    number of unique undirected, non-self-loop edges as the original.
    The sampled graph has no self-loops and every edge is returned in both
    directions.

    Args:
        edge_index: LongTensor [2, E]
        num_nodes: int (optional). Defaults to ``edge_index.max() + 1``
            (0 for an empty ``edge_index``).
        seed: int, seed of the ``torch.Generator`` used for sampling.
        make_undirected: bool (keep True). Both directions are emitted either
            way; when True the edges are additionally ordered by their
            undirected key ``min(u, v) * num_nodes + max(u, v)``.
    Returns:
        LongTensor [2, 2*M] on ``edge_index``'s device ([2, 0] if M == 0).
    Raises:
        ValueError: if M exceeds ``num_nodes * (num_nodes - 1) / 2`` (no simple
            graph on ``num_nodes`` nodes can hold that many edges).
    """
    device = edge_index.device
    if num_nodes is None:
        # .max() is undefined on an empty tensor.
        num_nodes = int(edge_index.max().item()) + 1 if edge_index.numel() > 0 else 0

    # Count unique undirected edges in the original graph (self-loops dropped).
    ei = edge_index[:, edge_index[0] != edge_index[1]]
    lo = torch.minimum(ei[0], ei[1])
    hi = torch.maximum(ei[0], ei[1])
    m = torch.unique(lo * num_nodes + hi).numel()  # undirected edges to sample

    if m == 0:
        # Nothing to sample; building the output from an empty list would give
        # a 1-D tensor and fail on row indexing.
        return torch.empty((2, 0), dtype=torch.long, device=device)

    max_edges = num_nodes * (num_nodes - 1) // 2
    if m > max_edges:
        # Rejection sampling below could never terminate.
        raise ValueError(
            f"cannot place {m} undirected edges on {num_nodes} nodes "
            f"(at most {max_edges} without self-loops)"
        )

    # Reproducible RNG
    g = torch.Generator(device=device)
    g.manual_seed(seed)

    # Rejection-sample undirected pairs in vectorised batches until we have m
    # unique edges. The batch size is roughly proportional to m (>= 1024).
    picked = set()
    edges_ab = []
    batch = max(1024, int(m * 1.3))
    while len(edges_ab) < m:
        u = torch.randint(num_nodes, (batch,), generator=g, device=device)
        v = torch.randint(num_nodes, (batch,), generator=g, device=device)
        keep = u != v
        u, v = u[keep], v[keep]
        for aa, bb in zip(torch.minimum(u, v).tolist(), torch.maximum(u, v).tolist()):
            key = aa * num_nodes + bb
            if key in picked:
                continue
            picked.add(key)
            edges_ab.append((aa, bb))
            if len(edges_ab) == m:
                break

    undirected = torch.tensor(edges_ab, dtype=torch.long, device=device).t()  # [2, m], row0 < row1

    if make_undirected:
        # edges_ab is already duplicate-free, so this only canonicalises the
        # order: sort by undirected key and decode (a, b) = divmod(key, n).
        keys = torch.unique(undirected[0] * num_nodes + undirected[1], sorted=True)
        undirected = torch.stack([
            torch.div(keys, num_nodes, rounding_mode='floor'),
            keys % num_nodes,
        ], dim=0)

    # Return with both directions
    return torch.cat([undirected, undirected.flip(0)], dim=1)



def permute_single_feature(data, feature_idx, seed=0):
    """
    Shuffle one feature column across nodes, leaving every other column as is.

    Args:
        data: PyG Data object
        feature_idx: int, index of the column to shuffle
        seed: int, seed of the ``torch.Generator`` (on ``data.x``'s device)
    Returns:
        A cloned Data object in which only column ``feature_idx`` of ``x`` is permuted.
    """
    dev = data.x.device
    rng = torch.Generator(device=dev)
    rng.manual_seed(seed)

    shuffled = data.clone()
    order = torch.randperm(data.num_nodes, generator=rng, device=dev)
    # Read from the untouched original so the write cannot alias the source.
    shuffled.x[:, feature_idx] = data.x[order, feature_idx]
    return shuffled


def reset_model_parameters(model):
    """
    Re-initialise, in ``model.modules()`` order (``model`` itself first), every
    module that exposes a ``reset_parameters`` method.
    """
    resettable = (module for module in model.modules() if hasattr(module, 'reset_parameters'))
    for module in resettable:
        module.reset_parameters()


def run_experiment(model, optimizer, data_obj, epochs=500, verbose_every=100):
    """
    Full-batch training with cross-entropy on ``data_obj.train_mask``, keeping
    the state dict from the epoch with the highest validation accuracy.

    Args:
        model: module called as ``model(x, edge_index)``; ``edge_index`` is
            ``None`` when ``data_obj`` has no such attribute.
        optimizer: optimizer over ``model``'s parameters.
        data_obj: object with ``x``, ``y``, ``train_mask``, ``val_mask``,
            ``test_mask`` (and optionally ``edge_index``).
        epochs: number of training epochs.
        verbose_every: print a progress line every this many epochs (and at epoch 1).
    Returns:
        (best_val_acc * 100, test_acc_at_best_val * 100, best_state_dict).
        If validation accuracy never rises above 0 (or ``epochs == 0``), the
        accuracies are 0 and the final model state is returned instead of
        raising UnboundLocalError.
    """
    edge_index = getattr(data_obj, 'edge_index', None)
    best_val, best_test_at_val = 0.0, 0.0
    best_model = None

    def acc(preds, mask):
        return (preds[mask] == data_obj.y[mask]).float().mean().item()

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()
        out = model(data_obj.x, edge_index)
        loss = F.cross_entropy(out[data_obj.train_mask], data_obj.y[data_obj.train_mask])
        loss.backward()
        optimizer.step()

        model.eval()
        with torch.no_grad():
            preds = model(data_obj.x, edge_index).argmax(-1)
            tr = acc(preds, data_obj.train_mask)
            va = acc(preds, data_obj.val_mask)
            te = acc(preds, data_obj.test_mask)
            if va > best_val:
                best_val, best_test_at_val = va, te
                best_model = copy.deepcopy(model.state_dict())  # save best model state

        if epoch % verbose_every == 0 or epoch == 1:
            print(f"Epoch {epoch:03d} | Loss {loss.item():.4f} | Train {tr:.3f} | Val {va:.3f} | Test {te:.3f} "
                  f"(Best Test@Val={best_test_at_val:.3f})")

    if best_model is None:
        # No epoch improved on the initial 0.0 validation accuracy.
        best_model = copy.deepcopy(model.state_dict())
    return 100 * best_val, 100 * best_test_at_val, best_model




# --- utils: permute a GROUP of features together -----------------------------
def permute_feature_block(data, feature_indices, seed=0):
    """
    Jointly permutes a set of feature columns across nodes (same permutation for all
    columns in 'feature_indices'), keeping other columns unchanged.

    Args:
        data: PyG Data object (expects .x of shape [num_nodes, num_features])
        feature_indices: list[int] or 1D torch.LongTensor of column indices to permute
        seed: int, RNG seed (torch.Generator on data.x's device)
    Returns:
        A new Data object with the selected columns permuted jointly.
    """
    data_perm = data.clone()
    device = data.x.device

    if not torch.is_tensor(feature_indices):
        feature_indices = torch.as_tensor(feature_indices, dtype=torch.long, device=device)

    g = torch.Generator(device=device).manual_seed(seed)
    perm = torch.randperm(data.num_nodes, generator=g, device=device)

    # Joint permutation: preserve within-group correlations, break alignment to nodes.
    # Select the k columns first, then permute rows: O(N*k) instead of copying
    # the whole N x F matrix; element (i, j) is x[perm[i], idx[j]] either way.
    data_perm.x[:, feature_indices] = data.x[:, feature_indices][perm]
    return data_perm





def pairwise_mi_matrix(X, random_state=0, n_neighbors=3, subsample=None, standardize=False):
    """
    Compute a symmetric pairwise MI matrix between columns of X (nodes x features).

    Only the upper triangle is estimated, via
    ``mutual_info_regression(X[:, [j]], X[:, i])``, and mirrored to the lower
    triangle, so the result is symmetric by construction.

    Args:
        X: array-like [N, F] (converted with ``np.asarray``)
        random_state: int, seeds the row subsample and the MI estimator
        n_neighbors: kNN param for MI estimator
        subsample: int or None. If set, randomly sample this many rows (nodes) for speed.
        standardize: bool. If True, z-score each column before MI.

    Returns:
        MI: np.ndarray [F, F], symmetric, zeros on diagonal.
    """
    X = np.asarray(X)
    rng = np.random.RandomState(random_state)
    X_ = X
    if subsample is not None and subsample < X.shape[0]:
        idx = rng.choice(X.shape[0], size=subsample, replace=False)
        X_ = X[idx]
    if standardize:
        mu = X_.mean(axis=0, keepdims=True)
        sd = X_.std(axis=0, keepdims=True) + 1e-8
        X_ = (X_ - mu) / sd

    # `n_feat`, not `F`: `F` is the module-level torch.nn.functional alias.
    n_feat = X_.shape[1]
    MI = np.zeros((n_feat, n_feat), dtype=float)

    for i in range(n_feat):
        target = X_[:, i]
        for j in range(i + 1, n_feat):
            mij = mutual_info_regression(
                X_[:, [j]], target, n_neighbors=n_neighbors, random_state=random_state
            )[0]
            MI[i, j] = MI[j, i] = mij
    return MI


def mi_feature_groups_from_matrix(MI, n_clusters=20):
    """
    Cluster features using average-linkage AgglomerativeClustering on the
    distance ``D = 1 - MI / max(MI off-diagonal)``.

    Args:
        MI: [F, F] mutual information matrix
        n_clusters: number of groups to form

    Returns:
        feature_groups: list[list[int]], ordered by cluster label; empty
        clusters are omitted. With fewer than two features there is nothing
        to cluster: [] for F == 0, [[0]] for F == 1.
    """
    MI = np.asarray(MI, dtype=float)
    n_feat = MI.shape[0]
    if n_feat < 2:
        # The off-diagonal max below is undefined for fewer than two features.
        return [list(range(n_feat))] if n_feat else []

    # Convert to a distance in [0,1]: D = 1 - MI_norm
    max_mi = np.max(MI[np.triu_indices_from(MI, k=1)]) + 1e-12
    D = 1.0 - MI / max_mi
    np.fill_diagonal(D, 0.0)

    # `metric` replaces `affinity`, which scikit-learn removed in 1.4
    # (matches feature_groups_from_distance).
    clustering = AgglomerativeClustering(
        n_clusters=n_clusters,
        metric='precomputed',
        linkage='average'
    ).fit(D)

    labels = clustering.labels_
    groups = []
    for g in range(n_clusters):
        idxs = np.where(labels == g)[0].tolist()
        if idxs:
            groups.append(idxs)
    return groups


from scipy.spatial.distance import pdist, squareform

def feature_distance_matrix(X, method="correlation"):
    """
    Compute a symmetric distance matrix between features (columns of X).

    Args:
        X: np.ndarray [num_nodes, num_features]
        method: 'correlation' -> D = 1 - |Pearson r|, in [0, 1]; a zero-variance
                    column gets r = 0 (distance 1) to every other column.
                'euclidean'   -> Euclidean distance between columns divided by
                    the largest one (left as-is when all distances are 0).
    Returns:
        D: np.ndarray [num_features, num_features], symmetric, zero diagonal.
    Raises:
        ValueError: for any other ``method``.
    """
    if method == "correlation":
        # np.corrcoef returns a 0-d array for a single column; keep it 2-D.
        corr = np.atleast_2d(np.corrcoef(X, rowvar=False))  # shape: (F, F)
        corr = np.nan_to_num(corr)                           # NaNs from zero variance -> 0
        D = 1 - np.abs(corr)                                 # distance in [0, 1]
    elif method == "euclidean":
        D = squareform(pdist(X.T, metric='euclidean'))
        d_max = D.max() if D.size else 0.0
        if d_max > 0:
            # Normalize to [0, 1]; skipping when all columns coincide avoids 0/0 = NaN.
            D /= d_max
    else:
        raise ValueError("method must be 'correlation' or 'euclidean'")

    np.fill_diagonal(D, 0.0)
    return D


def feature_groups_from_distance(D, n_clusters=20):
    """
    Group features by average-linkage agglomerative clustering on a
    precomputed distance matrix.

    Args:
        D: [F, F] distance matrix (smaller = more similar)
        n_clusters: desired number of groups
    Returns:
        list of lists of feature indices, ordered by cluster label, with empty
        clusters dropped.
    """
    model = AgglomerativeClustering(
        n_clusters=n_clusters,
        metric='precomputed',
        linkage='average'
    )
    labels = model.fit(D).labels_
    members = (np.flatnonzero(labels == k).tolist() for k in range(n_clusters))
    return [m for m in members if m]



def load_dataset(DS, device, split_style='train_rest'):
    """
    Load a node-classification dataset and move its graph to ``device``.

    DS:
        'Cora', 'CiteSeer', 'PubMed'          -> Planetoid (public masks kept)
        'Photo', 'Computers'                  -> Amazon
        'Cornell', 'Texas', 'Wisconsin'       -> WebKB
        'chameleon', 'squirrel' (any case)    -> WikipediaNetwork
        'Actor'                               -> Actor
        'CS', 'Physics'                       -> Coauthor
        'WikiCS'                              -> WikiCS
        'roman_empire', 'amazon_ratings', 'minesweeper', 'tolokers', 'questions'
                                              -> HeterophilousGraphDataset
        anything else -> simulated data pickled at 'simData/node_class_<DS>.pkl'
    All PyG datasets get NormalizeFeatures; every non-Planetoid PyG dataset
    also gets a random 'train_rest' split (10% val, 20% test, rest train).
    device: target device for the returned ``data`` (simulated data included).
    split_style: currently ignored -- every random split is 'train_rest'.
        Kept for backward compatibility.

    Returns:
        (dataset, data, num_node_features, num_classes). For simulated data
        ``dataset`` is the unpickled object.
    """
    norm = T.NormalizeFeatures()

    # Helper: choose a split transform
    def make_split_transform():
        return T.RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0.2)

    def norm_and_split():
        # NOTE(review): PyG dataset transforms typically run on every
        # dataset[i] access, so re-indexing `dataset` may draw a new random
        # split -- confirm callers only use the returned `data`.
        return T.Compose([norm, make_split_transform()])

    data_root = os.path.join('.', 'data')
    wiki_name = DS.lower()  # WikipediaNetwork names are lowercase

    if DS in {'Cora', 'CiteSeer', 'PubMed'}:
        # Public Planetoid masks are kept, so only normalize features.
        dataset = Planetoid(root=os.path.join(data_root, 'Planetoid'), name=DS, transform=norm)
    elif DS in {'Photo', 'Computers'}:
        dataset = Amazon(root=os.path.join(data_root, 'Amazon'), name=DS,
                         transform=norm_and_split())
    elif DS in {'Cornell', 'Texas', 'Wisconsin'}:
        # WebKB (strongly heterophilic)
        dataset = WebKB(root=os.path.join(data_root, 'WebKB'), name=DS,
                        transform=norm_and_split())
    elif wiki_name in {'chameleon', 'squirrel'}:
        # WikipediaNetwork (heterophilic). geom_gcn_preprocess=True is an alternative.
        dataset = WikipediaNetwork(
            root=os.path.join(data_root, 'WikipediaNetwork'),
            name=wiki_name,
            transform=norm_and_split(),
            geom_gcn_preprocess=False,
        )
    elif DS == 'Actor':
        dataset = Actor(root=os.path.join(data_root, 'Actor'), transform=norm_and_split())
    elif DS in {'CS', 'Physics'}:
        dataset = Coauthor(root=os.path.join(data_root, 'Coauthor'), name=DS,
                           transform=norm_and_split())
    elif DS == 'WikiCS':
        dataset = WikiCS(root=os.path.join(data_root, 'WikiCS'), transform=norm_and_split())
    elif DS in {'roman_empire', 'amazon_ratings', 'minesweeper', 'tolokers', 'questions'}:
        dataset = HeterophilousGraphDataset(root=os.path.join(data_root, 'HeterophilousGraph'),
                                            name=DS, transform=norm_and_split())
    else:
        # Simulated dataset: a pickled dict whose 'data' entry holds the graph.
        # NOTE: pickle.load can execute arbitrary code -- only load trusted,
        # locally generated files.
        with open('simData/node_class_' + DS + '.pkl', 'rb') as f:
            dataset = pickle.load(f)
        # Move to `device` like every PyG branch does.
        data = dataset['data'].to(device)
        num_node_features = data.x.shape[1]
        num_classes = data.y.max().item() + 1  # assuming y is a tensor of class indices

        # DS is expected to look like '<prefix>_<abc...>' with 0/1 digits a, b, c.
        case = DS.split('_')[1]
        Y_A = bool(int(case[0]))
        X_A = bool(int(case[1]))
        Y_X = bool(int(case[2]))
        print(f"Dataset: {DS}, \n Y_A: {Y_A}, X_A: {X_A}, Y_X: {Y_X}")
        return dataset, data, num_node_features, num_classes

    data = dataset[0].to(device)
    return dataset, data, dataset.num_node_features, dataset.num_classes



def apply_feature_mask(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Scale each feature column of ``x`` (N, F) by the matching entry of ``mask`` (F,)."""
    row_mask = mask[None]  # add a leading axis so the mask broadcasts over nodes
    return x * row_mask

def permute_single_feature_on_tensor(x: torch.Tensor, feat_i: int, seed: int) -> torch.Tensor:
    """
    Return a copy of ``x`` in which column ``feat_i`` is shuffled across rows
    (nodes) with a permutation seeded by ``seed``; ``x`` itself is not modified.
    """
    rng = torch.Generator(device=x.device)
    rng.manual_seed(seed)
    order = torch.randperm(x.size(0), generator=rng, device=x.device)
    result = x.clone()
    result[:, feat_i] = x[order, feat_i]
    return result



def get_adj_np(data):
    """
    Return the dense adjacency matrix of ``data`` as a NumPy array.

    Without a ``batch`` vector, ``max_num_nodes`` is pinned to
    ``data.num_nodes``. Otherwise ``to_dense_adj`` infers the size from
    ``edge_index.max() + 1``, which drops trailing isolated nodes and
    misaligns A with ``data.x``. With a ``batch`` vector, node counts come from
    ``batch`` itself. The leading batch axis is squeezed away for a single graph.
    """
    batch = getattr(data, 'batch', None)
    if batch is not None:
        A = to_dense_adj(data.edge_index, batch=batch)
    else:
        A = to_dense_adj(data.edge_index, max_num_nodes=getattr(data, 'num_nodes', None))
    return A.squeeze(0).cpu().numpy()




def get_dense_x_split(data, mask):
    """Return the rows of ``data.x`` selected by ``mask`` as a dense NumPy array (sparse ``x`` is densified first)."""
    feats = data.x
    if getattr(feats, 'is_sparse', False):
        feats = feats.to_dense()
    return feats[mask].cpu().numpy()



def _standardize_cols(X, eps=1e-8):
    """Z-score each feature column: (X - mean) / (std + eps); eps keeps constant columns finite (they map to 0)."""
    mu = X.mean(axis=0, keepdims=True)
    sd = X.std(axis=0, keepdims=True)
    return (X - mu) / (sd + eps)


def compute_hsim_euc_per_feature(data, chunk_size=256):
    """
    h_sim-euc per feature: higher = more homophilous.
    We z-score features, then for each feature m compute
      s_m = - mean_{(u,v) in E} (x_u[m] - x_v[m])^2
    and min-max scale s across features to [0,1] for comparability.

    Args:
        data: object with ``x`` [N, M] (dense or sparse tensor) and ``edge_index`` [2, E].
        chunk_size: number of feature columns processed at once. Bounds peak
            memory to |E| x chunk_size instead of |E| x M.
    Returns:
        np.ndarray of length M. All zeros when there are no edges or all
        features score equally.
    """
    # X (CPU, dense)
    X = data.x.to_dense().cpu().numpy() if getattr(data.x, "is_sparse", False) else data.x.cpu().numpy()
    Xz = _standardize_cols(X)

    # edges (u,v)
    ei = data.edge_index.cpu().numpy()
    u, v = ei[0], ei[1]

    n_feat = Xz.shape[1]
    if u.size == 0 or n_feat == 0:
        # No edges: the per-edge mean is undefined. Report neutral zeros.
        return np.zeros(n_feat, dtype=Xz.dtype)

    chunk_size = max(1, int(chunk_size))
    mse_per_feat = np.empty(n_feat, dtype=Xz.dtype)
    for start in range(0, n_feat, chunk_size):
        block = Xz[:, start:start + chunk_size]
        diffs = block[u] - block[v]  # (|E|, chunk)
        mse_per_feat[start:start + chunk_size] = np.mean(diffs * diffs, axis=0)
    scores = -mse_per_feat  # smaller squared diff => more homophily

    # min-max to [0,1]
    mn, mx = float(scores.min()), float(scores.max())
    if mx > mn:
        return (scores - mn) / (mx - mn)
    return np.zeros_like(scores)

def compute_hattr_per_feature(data):
    """
    h_attr per feature (higher => stronger assortativity over edges).

    For every feature column, take the Pearson correlation between node
    values and the mean over each node's out-neighbours (row-normalised
    ``D^{-1} A X`` built from ``edge_index``). Only nodes with at least one
    outgoing edge are used. The correlation r in [-1, 1] is mapped to
    (r + 1) / 2. Returns zeros when fewer than two nodes have neighbours.
    """
    feats = data.x
    if getattr(feats, "is_sparse", False):
        feats = feats.to_dense()
    X = feats.cpu().numpy()
    n_nodes, n_feats = X.shape

    # Sparse adjacency (duplicate edges are summed by tocsr).
    src, dst = data.edge_index.cpu().numpy()
    weights = np.ones(src.size, dtype=np.float32)
    adj = sp.coo_matrix((weights, (src, dst)), shape=(n_nodes, n_nodes)).tocsr()

    out_deg = np.asarray(adj.sum(axis=1)).ravel()
    row_norm = sp.diags(1.0 / np.maximum(out_deg, 1.0)) @ adj
    neigh_mean = row_norm @ X  # zero rows for nodes without out-edges

    has_nbr = out_deg > 0
    if has_nbr.sum() < 2:
        return np.zeros(n_feats, dtype=float)

    own = X[has_nbr]
    nbr = neigh_mean[has_nbr]
    own_c = own - own.mean(axis=0, keepdims=True)
    nbr_c = nbr - nbr.mean(axis=0, keepdims=True)

    cov = (own_c * nbr_c).sum(axis=0)
    # +1e-12 keeps constant columns at r = 0 instead of 0/0.
    scale = np.sqrt((own_c ** 2).sum(axis=0) * (nbr_c ** 2).sum(axis=0)) + 1e-12
    return 0.5 * (cov / scale + 1.0)



def compute_hge_per_feature(data, nbins=30, sim_type="cos"):
    """
    h_GE per feature via (1 - normalized entropy) of the edge-wise similarity histogram.
    Higher -> neighbors are consistently similar on that feature (more homophily).

    sim_type:
      - "cos": product of z-scored values x_u[m] * x_v[m] (cosine for 1-D
               features up to normalisation), clipped to [-5, 5] and binned
               over that range.
      - "euc": exp(-|x_u[m] - x_v[m]|) on the raw (not z-scored) features,
               binned over [0, 1].
    Raises:
        ValueError: for any other ``sim_type``.
    Returns:
        np.ndarray of length M with values in [0, 1]; 0 for features with no edges.
    """
    if sim_type not in ("cos", "euc"):
        raise ValueError(f"Unknown sim_type={sim_type}")

    # X dense on CPU
    X = data.x.to_dense().cpu().numpy() if getattr(data.x, "is_sparse", False) else data.x.cpu().numpy()
    N, M = X.shape

    ei = data.edge_index.cpu().numpy()
    u, v = ei[0], ei[1]

    if sim_type == "cos":
        # z-score each feature for stability
        mu = X.mean(axis=0, keepdims=True)
        sd = X.std(axis=0, keepdims=True) + 1e-8
        source = (X - mu) / sd
        lo, hi = -5.0, 5.0
    else:
        source = X
        lo, hi = 0.0, 1.0

    # One contiguous row per feature. Similarities are computed feature by
    # feature, so only |E| values live at once instead of an |E| x M matrix.
    columns = np.ascontiguousarray(source.T)

    bins = np.linspace(lo, hi, nbins + 1, dtype=float)
    scores = np.zeros(M, dtype=float)
    logk = np.log(nbins + 1e-12)  # for normalization to [0,1]
    ent_scale = logk if logk > 0 else 1.0

    for m in range(M):
        col = columns[m]
        if sim_type == "cos":
            sims = np.clip(col[u] * col[v], -5.0, 5.0)
        else:
            sims = np.exp(-np.abs(col[u] - col[v]))  # in (0, 1]
        h, _ = np.histogram(sims, bins=bins)
        p = h.astype(float)
        s = p.sum()
        if s <= 0:
            scores[m] = 0.0
            continue
        p /= s
        # entropy in nats
        ent = -np.sum(p[p > 0] * np.log(p[p > 0]))
        scores[m] = 1.0 - float(np.clip(ent / ent_scale, 0.0, 1.0))  # higher = more homophily

    return scores  # length M in [0,1]


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GINConv, GATConv


class BaselineArch(nn.Module):
    """
    Base class for stacked message-passing GNNs.

    Subclasses implement ``_create_conv_layers`` and return an ``nn.ModuleList``
    of layers called as ``conv(x, edge_index)``. ``forward`` applies ``act``
    and dropout after every layer but the last, and ``l_act`` after the last.

    ``act`` / ``l_act`` default to a fresh ``nn.ReLU()`` / ``nn.Identity()``
    per instance. The defaults are None sentinels rather than module instances,
    so one default module is not shared across every model built with it.
    """
    def __init__(self, in_dim, hid_dim, out_dim, n_layers=2,
                 act=None, l_act=None, bias=True, dropout=0.0):
        super().__init__()
        assert n_layers >= 1, "n_layers must be >= 1"
        self.in_d = in_dim
        self.hid_d = hid_dim
        self.out_d = out_dim
        self.act = nn.ReLU() if act is None else act
        self.l_act = nn.Identity() if l_act is None else l_act
        self.dropout = nn.Dropout(p=dropout)

        # Subclass hook; it may read in_d / hid_d / out_d / act set above.
        self.convs = self._create_conv_layers(n_layers, bias)
        self.n_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

    # Subclasses must implement this:
    def _create_conv_layers(self, n_layers: int, bias: bool) -> nn.ModuleList:
        raise NotImplementedError

    # PyG-style forward: (x, edge_index)
    def forward(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = self.act(x)
            x = self.dropout(x)
        x = self.convs[-1](x, edge_index)
        return self.l_act(x)


class MLP(nn.Module):
    """Simple MLP used inside GINConv.

    Linear widths are ``in_dim -> hid_dim -> ... -> hid_dim -> out_dim`` with
    ``n_layers`` layers (a single ``in_dim -> out_dim`` layer when
    ``n_layers <= 1``). ``act`` and dropout follow every layer except the
    last; ``l_act`` is applied to the output.

    Args:
        act: hidden activation; a fresh ``nn.ReLU()`` per instance when None.
        l_act: output activation; None (what GIN passes) means identity.
            Previously ``l_act`` was stored but never applied.
    """
    def __init__(self, in_dim, hid_dim, out_dim, n_layers=2, dropout=0.0, bias=True,
                 act=None, l_act=None):
        super().__init__()
        self.act = nn.ReLU() if act is None else act
        self.l_act = nn.Identity() if l_act is None else l_act
        self.dropout = nn.Dropout(dropout)
        if n_layers <= 1:
            dims = [in_dim, out_dim]
        else:
            dims = [in_dim] + [hid_dim] * (n_layers - 1) + [out_dim]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out, bias=bias) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def forward(self, x):
        for lin in self.layers[:-1]:
            x = self.act(lin(x))
            x = self.dropout(x)
        return self.l_act(self.layers[-1](x))


class GIN(BaselineArch):
    """
    Graph Isomorphism Network: a stack of PyG ``GINConv`` layers, each wrapping
    an ``MLP`` with ``mlp_layers`` linear layers whose hidden width is ``hid_dim``.

    Layer widths are in -> hid -> ... -> hid -> out (just in -> out when
    ``n_layers == 1``). ``act`` defaults to a fresh ``nn.ReLU()`` and ``l_act``
    to a fresh ``nn.Identity()`` per instance. ``aggregator`` is passed to
    GINConv as ``aggr`` ('sum' | 'mean' | 'max').
    """
    def __init__(self, in_dim, hid_dim, out_dim, n_layers=2, act=None, l_act=None,
                 bias=True, dropout=0.0, aggregator='sum', mlp_layers=2,
                 init_eps=0.0, learn_eps=False):
        # Plain attributes set before nn.Module.__init__ because
        # _create_conv_layers (called from BaselineArch.__init__) reads them.
        self.aggregator = aggregator          # 'sum' | 'mean' | 'max'
        self.mlp_layers = mlp_layers
        self.init_eps = init_eps
        self.learn_eps = learn_eps
        act = nn.ReLU() if act is None else act
        l_act = nn.Identity() if l_act is None else l_act
        super().__init__(in_dim, hid_dim, out_dim, n_layers, act, l_act, bias, dropout)

    def _gin_layer(self, d_in, d_out, bias):
        """Build one GINConv whose inner MLP maps d_in -> hid_d -> ... -> d_out."""
        nn_inner = MLP(d_in, self.hid_d, d_out, n_layers=self.mlp_layers,
                       dropout=0.0, bias=bias, act=self.act, l_act=None)
        return GINConv(nn_inner, eps=self.init_eps, train_eps=self.learn_eps,
                       aggr=self.aggregator)

    def _create_conv_layers(self, n_layers: int, bias: bool) -> nn.ModuleList:
        # Same layer sequence (and construction order) as before:
        # in -> hid, (hid -> hid) * (n_layers - 2), hid -> out; or in -> out.
        dims = [self.in_d] + [self.hid_d] * (n_layers - 1) + [self.out_d]
        return nn.ModuleList(
            self._gin_layer(d_in, d_out, bias) for d_in, d_out in zip(dims[:-1], dims[1:])
        )


class GAT(BaselineArch):
    """
    PyTorch Geometric GAT.
    - Hidden layers: concat=True -> output dim = heads * hid_dim
    - Last layer: concat=False, heads=1 -> output dim = out_dim
    ``act`` defaults to a fresh ``nn.ELU()`` and ``l_act`` to a fresh
    ``nn.Identity()`` per instance. ``attn_dropout`` is GATConv's attention
    dropout; ``dropout`` is the between-layer dropout of BaselineArch.
    """
    def __init__(self, in_dim, hid_dim, out_dim, n_layers=2, num_heads=4,
                 act=None, l_act=None, bias=True, dropout=0.0,
                 attn_dropout=0.6, add_self_loops=True):
        # Read by _create_conv_layers during BaselineArch.__init__.
        self.num_heads = num_heads
        self.attn_dropout = attn_dropout
        self.add_self_loops = add_self_loops
        act = nn.ELU() if act is None else act
        l_act = nn.Identity() if l_act is None else l_act
        super().__init__(in_dim, hid_dim, out_dim, n_layers, act, l_act, bias, dropout)

    def _gat_layer(self, d_in, d_out, heads, concat, bias):
        """Build one GATConv with the model-wide attention dropout / self-loop settings."""
        return GATConv(d_in, d_out, heads=heads, concat=concat,
                       dropout=self.attn_dropout, add_self_loops=self.add_self_loops,
                       bias=bias)

    def _create_conv_layers(self, n_layers: int, bias: bool) -> nn.ModuleList:
        convs = nn.ModuleList()
        d_in = self.in_d
        # n_layers - 1 multi-head hidden layers; concatenated heads widen the
        # next layer's input to hid_d * num_heads.
        for _ in range(n_layers - 1):
            convs.append(self._gat_layer(d_in, self.hid_d, self.num_heads, True, bias))
            d_in = self.hid_d * self.num_heads
        # final (-> out), concat=False, heads=1
        convs.append(self._gat_layer(d_in, self.out_d, 1, False, bias))
        return convs
