import math
import operator
from typing import List, Optional, Tuple

import pytorch_lightning as pl
import torch
import torch_geometric as pyg
from ogb.lsc import MAG240MDataset
from ogb.nodeproppred import PygNodePropPredDataset
from torch.utils.data import Dataset
from torch_geometric import utils
from torch_geometric.data import Batch, Data, InMemoryDataset
from torch_geometric.datasets import TUDataset, WikipediaNetwork
from torch_geometric.loader import DataLoader, NeighborLoader, RandomNodeLoader
from torch_geometric.utils import to_undirected
from sklearn.model_selection import StratifiedShuffleSplit

from manifold_transformers import nc_datasets

# Module-level compute device: CUDA when available, otherwise CPU.
# NOTE(review): `device` is not referenced anywhere in this module; other modules
# may import it. TODO: move device selection out of the data-loading module.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def create_datamodule_from_config(config) -> pl.LightningDataModule:
    """
    Create the datamodule named by ``config["dataset"]`` and infer model dimensions.

    Besides returning the datamodule, this writes two keys into ``config``:
    ``"in_features"`` (node feature width) and ``"num_classes"``.

    For single-graph datasets ``num_classes`` is taken from the datamodule's
    ``num_classes`` attribute when it has one (the builders in this module
    compute it on the full graph before splitting). Only when that attribute is
    missing does it fall back to ``max(train label) + 1``. Reading it from the
    (possibly downsampled) training graph alone can undercount classes when the
    highest label is absent from the training sample.

    Args:
        config: Dictionary containing dataset configuration. Must contain
            ``"dataset"`` plus whatever keys the dataset-specific builder reads.

    Returns:
        The configured datamodule.

    Raises:
        ValueError: If ``config["dataset"]`` is not a supported dataset name.
    """
    dataset_name = config["dataset"]

    if dataset_name == "reddit-binary":
        datamodule = construct_reddit_datamodule(config)
        # Multi-graph dataset: read the feature width off one collated batch.
        sample_batch = next(iter(datamodule.train_dataloader()))
        config["in_features"] = sample_batch.x.size(1)
        config["num_classes"] = 2  # REDDIT-BINARY has two graph labels.
        return datamodule

    if dataset_name == "ogbn-arxiv":
        # NOTE(review): construct_arxiv_datamodule is not defined in this module;
        # confirm it is provided elsewhere, otherwise this branch raises NameError.
        datamodule = construct_arxiv_datamodule(config)
    elif dataset_name == "arxiv-year":
        datamodule = construct_arxiv_year_datamodule(config)
    elif dataset_name == "ogbn-mag":
        datamodule = construct_mag_datamodule(config)
    elif dataset_name == "snap-patents":
        datamodule = construct_snap_patents_datamodule(config)
    elif dataset_name == "chameleon":
        datamodule = construct_chameleon_datamodule(config)
    elif dataset_name == "pokec":
        datamodule = construct_pokec_datamodule(config)
    else:
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # Single-graph datasets: infer dimensions from the training graph.
    config["in_features"] = datamodule.train_data.x.size(1)
    num_classes = getattr(datamodule, "num_classes", None)
    if num_classes is None:
        num_classes = int(datamodule.train_data.y.max().item()) + 1
    config["num_classes"] = num_classes
    return datamodule


def _downsample_data(datas: List[Data] | Data, fraction: float):
    """
    Return random induced subgraphs keeping ``fraction`` of each graph's nodes.

    For every graph, ``int(num_nodes * fraction)`` nodes are drawn uniformly
    with ``torch.randperm`` (global torch RNG), and the induced subgraph is
    built with ``_subsample_graph_n_nodes``.

    Args:
        datas: A single ``Data`` graph or a list of graphs.
        fraction: Fraction of nodes to keep in each graph; must be >= 0.

    Returns:
        A ``Data`` when a single graph was given, otherwise a list of ``Data``.
        In the list case, graphs whose sample rounds down to zero nodes are
        dropped, so the result can be shorter than the input. A single graph
        always keeps at least one node (if it has any) instead of failing.

    Raises:
        ValueError: If ``fraction`` is negative.
    """
    if fraction < 0:
        raise ValueError(f"fraction must be non-negative, got {fraction}.")

    is_list = isinstance(datas, list)
    graphs = datas if is_list else [datas]

    new_data = []
    for data in graphs:
        num_nodes = int(data.num_nodes * fraction)
        if num_nodes == 0:
            if is_list:
                continue
            # A lone graph cannot be dropped (the caller expects a Data back),
            # so keep the smallest non-empty sample instead of raising IndexError.
            num_nodes = min(1, data.num_nodes)
        indices = torch.randperm(data.num_nodes)[:num_nodes]
        new_data.append(_subsample_graph_n_nodes(data, indices=indices))

    return new_data if is_list else new_data[0]


def _subsample_graph_n_nodes(
    data: Data, indices: Optional[torch.Tensor] = None
) -> Data:
    """
    Return the subgraph of ``data`` induced by the nodes in ``indices``.

    Edges are restricted to those with both endpoints in ``indices`` and
    relabelled with ``utils.subgraph(..., relabel_nodes=True)``. Node features
    are gathered as ``x[indices]``. Labels are gathered the same way when
    their first dimension equals the node count; otherwise (e.g. a graph-level
    label) they are carried over unchanged. Only ``x``, ``edge_index``, ``y``
    and ``num_nodes`` are copied; other attributes (edge features, masks, ...)
    are not.

    Args:
        data: The original graph.
        indices: Node ids to keep. Non-tensor inputs (lists, numpy arrays) are
            converted to a ``long`` tensor. Required despite the ``None``
            default, which is kept for signature compatibility.

    Returns:
        A new ``Data`` object with ``len(indices)`` nodes.

    Raises:
        ValueError: If ``indices`` is None.
    """
    if indices is None:
        # There is no size argument to sample with, so indices is mandatory.
        # Fail with a clear message instead of an opaque error inside PyG.
        raise ValueError("indices must be provided to _subsample_graph_n_nodes.")
    if not isinstance(indices, torch.Tensor):
        indices = torch.tensor(indices, dtype=torch.long)

    edge_index, _ = utils.subgraph(
        indices, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes
    )

    # Node-level labels are indexed; anything else (graph-level) is preserved.
    y_sub = None
    y_tensor = getattr(data, "y", None)
    if y_tensor is not None:
        is_node_level = (
            isinstance(y_tensor, torch.Tensor)
            and y_tensor.dim() > 0
            and y_tensor.size(0) == data.num_nodes
        )
        y_sub = y_tensor[indices] if is_node_level else y_tensor

    # Featureless graphs are allowed; previously x=None raised TypeError here.
    x = getattr(data, "x", None)
    return Data(
        x=x[indices] if x is not None else None,
        edge_index=edge_index,
        y=y_sub,
        num_nodes=indices.numel(),
    )


def _stratified_split_into_train_val_test(
    data: List[Data] | Data,
    test_ratio: float,
    val_ratio: float,
    fold_idx: int,
    k: int,
    seed=42,
) -> Tuple[Data, Data, Data]:
    """
    Split a graph (node split) or a list of graphs (graph split) into stratified
    train / validation / test parts.

    The outer split draws ``k`` shuffled stratified train+val / test splits with
    ``StratifiedShuffleSplit`` and uses the ``fold_idx``-th. An inner stratified
    split then carves validation out of train+val, sized so that validation is
    ``val_ratio`` of the whole. Both splits use ``random_state=seed``.

    Args:
        data: A single PyG ``Data`` graph (nodes are split, stratified by
            ``data.y``) or a list of graphs (graphs are split, stratified by
            their scalar label or the first entry of ``y``).
        test_ratio: Fraction of items used for testing.
        val_ratio: Fraction of items used for validation.
        fold_idx: Which of the ``k`` outer splits to use. Negative values
            count from the end, as with list indexing.
        k: Number of outer splits to draw.
        seed: Random seed for reproducibility.

    Returns:
        ``(train, val, test)``: induced subgraphs for a single graph, or lists
        of graphs for a list input.

    Raises:
        ValueError: If ``fold_idx`` is out of range for ``k``.
    """
    from itertools import islice

    if fold_idx < 0:
        fold_idx += k
    if not 0 <= fold_idx < k:
        raise ValueError(f"fold_idx must be in [-{k}, {k}), got {fold_idx}.")

    is_list = isinstance(data, list)
    if is_list:
        labels = torch.tensor(
            [g.y.item() if g.y.dim() == 0 else g.y[0].item() for g in data]
        )
        idx = torch.arange(len(data))
    else:
        labels = data.y
        idx = torch.arange(data.num_nodes)

    # Outer split: the splits are generated sequentially from one RNG, so
    # lazily skipping to fold_idx yields exactly the same fold as building the
    # full list, without keeping all k index arrays in memory.
    outer = StratifiedShuffleSplit(n_splits=k, test_size=test_ratio, random_state=seed)
    train_val_indices, test_indices = next(
        islice(outer.split(idx, labels), fold_idx, None)
    )

    # Inner split on train+val positions; rescale so val is val_ratio overall.
    inner = StratifiedShuffleSplit(
        n_splits=1, test_size=val_ratio / (1 - test_ratio), random_state=seed
    )
    train_pos, val_pos = next(
        inner.split(torch.arange(len(train_val_indices)), labels[train_val_indices])
    )
    train_indices = train_val_indices[train_pos]
    val_indices = train_val_indices[val_pos]

    if is_list:
        train_data = [data[i] for i in train_indices]
        val_data = [data[i] for i in val_indices]
        test_data = [data[i] for i in test_indices]
    else:
        train_data = _subsample_graph_n_nodes(data, indices=train_indices)
        val_data = _subsample_graph_n_nodes(data, indices=val_indices)
        test_data = _subsample_graph_n_nodes(data, indices=test_indices)

    return train_data, val_data, test_data


def _memory_downsample_data(data: Data, max_nodes: int):
    """
    Downsample the graph to have at most max_nodes nodes if it exceeds that size.
    """
    if data.num_nodes > max_nodes:
        print(f"Downsampling graph from {data.num_nodes} to {max_nodes} nodes to fit memory constraints.")
        indices = torch.randperm(data.num_nodes)[:max_nodes]
        data = _subsample_graph_n_nodes(data, indices=indices)
    return data


class SingleGraphGrowingNodeDataset(Dataset):
    """
    One-item dataset of random induced subgraphs of a single graph, used for
    growing-graph (transferability) evaluation.

    Item 0 is a flat ``List[Data]``. For every fraction in ``test_fractions``
    (in order), it holds ``test_num_batches_per_size`` independently sampled
    induced subgraphs with ``int(test_data.num_nodes * frac)`` nodes each.
    Fractions that round down to zero nodes are skipped. Every subgraph is
    tagged with ``sample_fraction`` (the requested fraction) and
    ``sample_size`` (its node count).

    Sampling uses a private ``torch.Generator``. Each ``__getitem__`` call
    advances it, so repeated calls return different subgraphs.

    Args:
        test_data: Graph to sample from when creating evaluation subgraphs.
        test_fractions: Fractions of nodes to retain for each evaluation size.
            ``None`` means no fractions (item 0 is then an empty list).
        test_num_batches_per_size: Number of random subgraphs to draw for each
            fraction; must be positive.
        seed: Optional manual seed for the generator. When ``None`` the
            generator keeps ``torch.Generator``'s default initial state.

    Raises:
        ValueError: If ``test_num_batches_per_size`` is not positive.
    """

    def __init__(
        self,
        test_data: "Data",
        test_fractions: Optional[List[float]] = None,
        test_num_batches_per_size: int = 5,
        seed: Optional[int] = None,
    ):
        self.test_data = test_data
        self.test_fractions = test_fractions or []
        if test_num_batches_per_size <= 0:
            raise ValueError("test_num_batches_per_size must be a positive integer.")
        self.test_num_batches_per_size = test_num_batches_per_size
        self._generator = torch.Generator()
        if seed is not None:
            self._generator.manual_seed(seed)

    def __len__(self):
        return 1

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError("This dataset contains only one graph sample.")

        subsamples: List[Data] = []
        for frac in self.test_fractions:
            total_nodes = self.test_data.num_nodes
            n = int(total_nodes * frac)
            if n <= 0:
                continue
            for _ in range(self.test_num_batches_per_size):
                # Uniformly sample n distinct nodes with the private generator.
                node_idx = torch.randperm(total_nodes, generator=self._generator)[:n]
                subsample = _subsample_graph_n_nodes(self.test_data, indices=node_idx)
                subsample.sample_fraction = float(frac)
                subsample.sample_size = int(subsample.num_nodes)
                subsamples.append(subsample)

        return subsamples



def _ensure_undirected_graph(data):
    if data is None:
        return data

    if isinstance(data, list):
        return [_ensure_undirected_graph(graph) for graph in data]

    if getattr(data, "edge_index", None) is None:
        return data

    data.edge_index = to_undirected(data.edge_index, num_nodes=data.num_nodes)
    return data


class MultiGraphGrowingNodeDataset(Dataset):
    """
    One-item evaluation dataset over a list of test graphs at several node counts.

    Item 0 is a flat ``List[Data]`` built in two stages:

    1. A clone of every test graph, tagged ``sample_fraction=1.0``.
    2. For each fraction in ``test_fractions``, a random induced subgraph of
       each test graph keeping that fraction of its nodes (via
       ``_downsample_data``). Graphs whose sample would be empty are skipped.

    Every returned graph carries ``sample_fraction`` and ``sample_size``
    (its node count).

    Args:
        test_data: List of test graphs.
        test_fractions: Node fractions to evaluate at. ``None`` is treated as
            an empty list, so only the full graphs are returned.
    """

    def __init__(self, test_data: list, test_fractions: Optional[list] = None):
        self.test_data = test_data
        self.fractions = test_fractions if test_fractions is not None else []

    def __len__(self):
        return 1  # single evaluation batch

    def __getitem__(self, idx):
        if idx != 0:
            raise IndexError("This dataset contains only one graph sample.")

        samples: List[Data] = []

        # Include full test graphs as baseline samples (cloned so tagging does
        # not mutate the stored graphs).
        for graph in self.test_data:
            graph_clone = graph.clone()
            graph_clone.sample_fraction = 1.0
            graph_clone.sample_size = int(getattr(graph_clone, "num_nodes", 0))
            samples.append(graph_clone)

        for frac in self.fractions:
            downsampled = _downsample_data(self.test_data, frac)
            if not isinstance(downsampled, list):
                downsampled = [downsampled]
            for graph in downsampled:
                graph.sample_fraction = float(frac)
                graph.sample_size = int(getattr(graph, "num_nodes", 0))
                samples.append(graph)

        return samples


class SingleGraphDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_data: Data,
        val_data: Data,
        test_data: Data,
        num_workers: int = 1,
        is_neighbor_loader: bool = True,
        num_hops: int | None = None,
        batch_size: int = 256,
        neighbor_config: Optional[List[int]] = None,
        seed: int = 0,
        test_fractions: Optional[List[float]] = None,
        test_num_batches_per_size: int = 5,
    ):
        """
        PTL DataModule for single-graph datasets.

        - Training uses either ``NeighborLoader`` mini-batches or precomputed
          random node-induced subgraphs (``RandomSubgraphDataset``).
        - Validation runs on the whole validation graph as one batch.
        - Testing yields one composite sample from ``SingleGraphGrowingNodeDataset``:
          a list of random subgraphs of the test graph at each of ``test_fractions``.

        Args:
            train_data: PyG single graph Data object for training.
            val_data: PyG single graph Data object for validation.
            test_data: PyG single graph Data object for testing.
            num_workers: Number of workers for data loading.
            is_neighbor_loader: True to train with ``NeighborLoader``; False to
                train on ``RandomSubgraphDataset`` subgraphs of ``batch_size``
                nodes. The string ``"neighbor"`` is also accepted (True); any
                other string means False.
            num_hops: If given, overrides the per-hop fan-out: every hop samples
                ``num_hops`` neighbors, with one hop per entry of
                ``neighbor_config``. NOTE(review): the name suggests a hop
                count, but it is used as a fan-out.
            batch_size: Seed nodes per ``NeighborLoader`` batch, or nodes per
                random subgraph when ``is_neighbor_loader`` is False.
            neighbor_config: Per-hop neighbor fan-out for ``NeighborLoader``
                (default: [20, 15, 10, 5]).
            seed: Random seed for the random-subgraph training set and for the
                test-time subgraph sampler. ``None`` is treated as 0 for both.
            test_fractions: Fractions of nodes to use when generating evaluation subgraphs.
            test_num_batches_per_size: Number of subsamples to draw for each fraction.
        """
        super().__init__()
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.num_workers = num_workers
        self.is_neighbor_loader: bool = is_neighbor_loader
        # Config files may pass the loader type as a string such as "neighbor".
        if isinstance(self.is_neighbor_loader, str):
            self.is_neighbor_loader = self.is_neighbor_loader == "neighbor"
        self.num_hops = num_hops
        self.batch_size = batch_size
        self.neighbor_config = neighbor_config or [20, 15, 10, 5]
        self.seed = 0 if seed is None else int(seed)
        self._train_random_dataset: Optional["RandomSubgraphDataset"] = None
        self.test_fractions = test_fractions or []
        self.test_num_batches_per_size = test_num_batches_per_size

        # Use the normalized seed so training and test sampling agree on what
        # a ``None`` seed means.
        self.test_dataset = SingleGraphGrowingNodeDataset(
            test_data=self.test_data,
            test_fractions=self.test_fractions,
            test_num_batches_per_size=self.test_num_batches_per_size,
            seed=self.seed,
        )

    def train_dataloader(self):
        """Return the training loader (NeighborLoader or random-subgraph DataLoader)."""
        if self.is_neighbor_loader:
            if self.num_hops is None:
                num_neighbors = self.neighbor_config
            else:
                num_neighbors = [self.num_hops] * len(self.neighbor_config)
            return NeighborLoader(
                data=self.train_data,
                num_neighbors=num_neighbors,
                batch_size=self.batch_size,  # Use configurable batch_size (default 256)
                shuffle=True,
                num_workers=self.num_workers,
                persistent_workers=self.num_workers > 0,
                pin_memory=True,
            )
        # Build (once) and reuse all random subgraphs of batch_size nodes.
        if self._train_random_dataset is None:
            self._train_random_dataset = RandomSubgraphDataset(
                data=self.train_data,
                nodes_per_subgraph=self.batch_size,
                seed=self.seed,
            )
        return DataLoader(
            dataset=self._train_random_dataset,
            batch_size=1,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return a loader yielding the whole validation graph as a single batch."""
        return DataLoader(
            dataset=[self.val_data],
            batch_size=1,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
        )

    def test_dataloader(self):
        """Return a loader yielding the list of growing-graph test subgraphs."""
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=1,  # SingleGraphGrowingNodeDataset returns one composite sample
            shuffle=False,
            num_workers=self.num_workers,
            # Unwrap the size-1 batch so the list of subgraphs is passed as-is.
            collate_fn=operator.itemgetter(0),
        )


def _train_val_test_split_dataset(
    dataset: InMemoryDataset, train_frac=0.7, val_frac=0.1
) -> Tuple[InMemoryDataset, InMemoryDataset, InMemoryDataset]:
    """
    Split a PyG dataset into train, validation, and test sets.

    Args:
        dataset: PyG dataset to split
        train_frac: Fraction of data for training (default: 0.7)
        val_frac: Fraction of data for validation (default: 0.1)

    Returns:
        tuple: (train_dataset, val_dataset, test_dataset)
    """
    test_frac = 1 - train_frac - val_frac

    # Shuffle dataset
    dataset, dataset_idx = dataset.shuffle(return_perm=True)

    # Calculate indices for train, val, and test sets
    train_end = int(train_frac * len(dataset))
    val_end = train_end + int(val_frac * len(dataset))
    test_end = len(dataset)
    train_idx = dataset_idx[:train_end]
    val_idx = dataset_idx[train_end:val_end]
    test_idx = dataset_idx[val_end:test_end]

    train_dataset = dataset[train_idx]
    val_dataset = dataset[val_idx]
    test_dataset = dataset[test_idx]

    return train_dataset, val_dataset, test_dataset


class MultiGraphDataModule(pl.LightningDataModule):
    def __init__(
        self,
        train_data: List[Data],
        val_data: Optional[List[Data]] = None,
        test_dataset: Optional[MultiGraphGrowingNodeDataset] = None,
        num_workers: int = 1,
        batch_size: int = 256,
    ):
        """
        PTL DataModule for multi-graph (graph-level) datasets.

        Args:
            train_data: List of PyG graphs for training. The PyG ``DataLoader``
                collates ``batch_size`` graphs per step.
            val_data: List of PyG graphs for validation. NOTE(review):
                ``val_dataloader`` does not guard against ``None``, so a value
                is needed whenever validation runs.
            test_dataset: ``MultiGraphGrowingNodeDataset`` for testing
                (optional). ``test_dataloader`` returns None without it.
            num_workers: Number of workers for data loading.
            batch_size: Number of graphs per train/validation batch.
        """
        super().__init__()
        self.train_data = train_data
        self.val_data = val_data
        self.test_dataset = test_dataset
        self.num_workers = num_workers
        self.batch_size = batch_size

    def train_dataloader(self):
        """Return a shuffled loader over the training graphs."""
        return DataLoader(
            dataset=self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
        )

    def val_dataloader(self):
        """Return an unshuffled loader over the validation graphs."""
        return DataLoader(
            dataset=self.val_data,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
            pin_memory=True,
        )

    def test_dataloader(self):
        """Return a loader yielding the composite growing-graph test sample, or None."""
        if self.test_dataset is None:
            return None
        return DataLoader(
            dataset=self.test_dataset,
            batch_size=1,  # MultiGraphGrowingNodeDataset returns one composite sample
            shuffle=False,
            num_workers=self.num_workers,
            # Unwrap the size-1 batch so the list of graphs is passed as-is.
            collate_fn=operator.itemgetter(0),
        )


def _default_test_fractions(config, total_nodes: int) -> List[float]:
    """Derive sorted growing-graph test fractions from min/max/step node counts in ``config``."""
    min_nodes = config.get("min_testing_nodes")
    max_nodes = config.get("max_testing_nodes")
    step = config.get("new_nodes_per_test")
    total_nodes = max(1, int(total_nodes))
    fractions: List[float] = []
    if (
        isinstance(min_nodes, int)
        and isinstance(max_nodes, int)
        and isinstance(step, int)
        and step > 0
    ):
        for nodes in range(min_nodes, max_nodes + 1, step):
            frac = max(0.0, min(1.0, float(nodes) / float(total_nodes)))
            if frac > 0.0:
                fractions.append(frac)
    # Fall back to evaluating on the full test graph only.
    return sorted(set(fractions)) or [1.0]


def _resolve_single_graph_batch_size(config, num_train_nodes: int) -> int:
    """Resolve the training batch size from ``config`` and record derived values back into it."""
    batch_size = config.get("batch_size")
    if batch_size is None:
        # No explicit size: use a fraction of the training nodes.
        pct = config.get("train_size_pct")
        if pct is None:
            raise ValueError("train_size_pct must be provided for single-graph datasets.")
        if pct <= 0 or pct > 1:
            raise ValueError("--train_size_pct must be in the interval (0, 1].")
        batch_size = min(num_train_nodes, max(1, math.ceil(pct * num_train_nodes)))
        config["batch_size"] = batch_size
    else:
        batch_size = int(batch_size)
        if batch_size <= 0:
            raise ValueError("--batch_size must be a positive integer.")
        batch_size = min(batch_size, num_train_nodes)
        config["batch_size"] = batch_size
        config["train_size_pct"] = batch_size / num_train_nodes

    config["effective_batch_size"] = batch_size
    config["num_train_nodes"] = num_train_nodes
    return batch_size


def construct_single_graph_datamodule(data: Data, config):
    """
    Split one graph into train/val/test subgraphs and wrap them in a SingleGraphDataModule.

    Steps:

    1. Stratified node split (``test_ratio``, ``val_ratio``, ``fold_idx``,
       ``k_fold_stratified_split``).
    2. Cap every split at ``max_memory_nodes``.
    3. Keep ``train_downsample_fraction`` of the capped training graph.
    4. Optionally symmetrize edges (``force_undirected``).
    5. Derive test fractions and the batch size when not configured.

    Mutates ``config``: fills defaults for ``test_num_batches_per_size`` and
    ``test_downsample_fractions``. It also records ``batch_size``,
    ``effective_batch_size`` and ``num_train_nodes``, plus ``train_size_pct``
    when an explicit batch size was given.

    Returns:
        SingleGraphDataModule whose ``num_classes`` is ``max(data.y) + 1`` over
        the full (unsplit) graph.
    """
    print(f"Original graph: {data.num_nodes} nodes, {data.num_edges} edges")
    full_train_data, full_val_data, full_test_data = _stratified_split_into_train_val_test(
        data,
        config["test_ratio"],
        config["val_ratio"],
        config["fold_idx"],
        config["k_fold_stratified_split"],
    )
    print(f"After split - Train: {full_train_data.num_nodes} nodes, Val: {full_val_data.num_nodes} nodes, Test: {full_test_data.num_nodes} nodes")

    # Make graphs smaller if required
    max_memory_nodes = config["max_memory_nodes"]
    train_data = _memory_downsample_data(full_train_data, max_memory_nodes)
    val_data = _memory_downsample_data(full_val_data, max_memory_nodes)
    test_data = _memory_downsample_data(full_test_data, max_memory_nodes)

    # Downsample the *memory-capped* train graph. Using full_train_data here
    # (as before) silently discarded the max_memory_nodes cap for training.
    train_data = _downsample_data(train_data, fraction=config["train_downsample_fraction"])
    print(f"After downsampling - Train: {train_data.num_nodes} nodes, Val: {val_data.num_nodes} nodes, Test: {test_data.num_nodes} nodes")

    if config.get("force_undirected", False):
        train_data = _ensure_undirected_graph(train_data)
        val_data = _ensure_undirected_graph(val_data)
        test_data = _ensure_undirected_graph(test_data)

    # Val/test graphs are evaluated whole; tag them like growing-graph samples.
    for graph in (val_data, test_data):
        graph.sample_fraction = 1.0
        graph.sample_size = int(getattr(graph, "num_nodes", 0))

    config.setdefault("test_num_batches_per_size", 5)
    if "test_downsample_fractions" not in config:
        config["test_downsample_fractions"] = _default_test_fractions(config, test_data.num_nodes)

    batch_size = _resolve_single_graph_batch_size(config, train_data.num_nodes)

    num_classes = int(data.y.max().item()) + 1
    datamodule = SingleGraphDataModule(
        train_data=train_data,
        val_data=val_data,
        test_data=test_data,
        num_workers=config["num_workers"],
        is_neighbor_loader=config["dataloader_type"] == "neighbor",
        batch_size=batch_size,
        neighbor_config=config.get("neighbor_num_neighbors") or [20, 15, 10, 5],
        seed=config.get("seed", 0),
        test_fractions=config["test_downsample_fractions"],
        test_num_batches_per_size=config["test_num_batches_per_size"],
    )
    datamodule.num_classes = num_classes
    return datamodule


def _average_degree(graphs: List[Data]) -> float:
    """Mean of ``num_edges / num_nodes`` over ``graphs`` (empty graphs count as 0; empty list gives 0.0)."""
    degrees = [g.num_edges / g.num_nodes if g.num_nodes else 0.0 for g in graphs]
    return sum(degrees) / len(degrees) if degrees else 0.0


def construct_multi_graph_datamodule(data: List[Data], config):
    """
    Split a collection of graphs and wrap the splits in a MultiGraphDataModule.

    Steps:

    1. Stratified graph-level split (``test_ratio``, ``val_ratio``,
       ``fold_idx``, ``k_fold_stratified_split``).
    2. Give featureless graphs a constant 1-dim ``x``.
    3. When ``train_downsample_fraction < 1``, node-downsample every training
       graph.
    4. Optionally symmetrize edges (``force_undirected``).

    Test evaluation uses ``config["test_downsample_fractions"]`` when present.
    NOTE(review): it previously raised KeyError when absent; it now defaults to
    no fractions, so only the full test graphs are evaluated.
    """
    print(f"Original graph: {len(data)} graphs")
    data_list = [data[i] for i in range(len(data))]
    full_train_data, full_val_data, full_test_data = _stratified_split_into_train_val_test(
        data_list,
        config["test_ratio"],
        config["val_ratio"],
        config["fold_idx"],
        config["k_fold_stratified_split"],
    )
    print(f"After split - Train: {len(full_train_data)} graphs, Val: {len(full_val_data)} graphs, Test: {len(full_test_data)} graphs")
    avg_degrees = _average_degree(full_train_data)
    print(f"Average degrees of train graphs: {avg_degrees:.2f}")

    # Add node features if missing
    for split in (full_train_data, full_val_data, full_test_data):
        for graph in split:
            if getattr(graph, "x", None) is None:
                graph.x = torch.ones((graph.num_nodes, 1), dtype=torch.float)

    # Downsample train dataset if needed (average degree computed once, after
    # the loop, rather than re-scanning all graphs on every iteration).
    fraction = config["train_downsample_fraction"]
    if fraction < 1.0:
        train_data = [_downsample_data(graph, fraction=fraction) for graph in full_train_data]
        avg_degrees = _average_degree(train_data)
    else:
        train_data = full_train_data
    print(f"After downsampling - Average degrees of downsampled train graphs: {avg_degrees:.2f}")

    if config.get("force_undirected", False):
        train_data = _ensure_undirected_graph(train_data)
        full_val_data = _ensure_undirected_graph(full_val_data)
        full_test_data = _ensure_undirected_graph(full_test_data)

    test_dataset = MultiGraphGrowingNodeDataset(
        test_data=full_test_data,
        test_fractions=config.get("test_downsample_fractions", []),
    )

    return MultiGraphDataModule(
        train_data=train_data,
        val_data=full_val_data,
        test_dataset=test_dataset,
        num_workers=config["num_workers"],
        batch_size=config["batch_size"],
    )


def construct_reddit_datamodule(config) -> MultiGraphDataModule:
    """Load TUDataset REDDIT-BINARY (cached under data/TUDataset) and build its multi-graph datamodule."""
    reddit_graphs = TUDataset(root="data/TUDataset", name="REDDIT-BINARY")
    return construct_multi_graph_datamodule(reddit_graphs, config)


def construct_mag_datamodule(config) -> SingleGraphDataModule:
    """
    Create the ogbn-mag datamodule from its paper-only citation subgraph.

    Loads ``PygNodePropPredDataset("ogbn-mag")`` and keeps only:

    - paper nodes (features ``x_dict["paper"]``, labels ``y_dict["paper"]``
      squeezed);
    - ``("paper", "cites", "paper")`` edges (empty if that relation is absent).

    It then delegates to ``construct_single_graph_datamodule``.

    Args:
        config: Dictionary containing configuration parameters.

    Returns:
        SingleGraphDataModule instance (unified interface).
    """
    # torch 2.6 changed torch.load to weights_only=True by default, which
    # rejects these PyG classes when the processed dataset is unpickled, so
    # allow-list them. add_safe_globals is absent on older torch releases
    # (which don't need it); guard instead of raising AttributeError there.
    add_safe_globals = getattr(torch.serialization, "add_safe_globals", None)
    if add_safe_globals is not None:
        from torch_geometric.data.data import DataEdgeAttr, DataTensorAttr
        from torch_geometric.data.storage import GlobalStorage

        add_safe_globals([DataEdgeAttr, DataTensorAttr, GlobalStorage])

    dataset = PygNodePropPredDataset(name="ogbn-mag")
    hetero_data = dataset[0]

    # Convert heterogeneous to homogeneous data (paper-only graph)
    paper_x = hetero_data.x_dict["paper"]
    paper_y = hetero_data.y_dict["paper"].squeeze()
    if ("paper", "cites", "paper") in hetero_data.edge_index_dict:
        edge_index = hetero_data.edge_index_dict[("paper", "cites", "paper")]
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    original_data = Data(
        x=paper_x,
        edge_index=edge_index,
        y=paper_y,
        num_nodes=paper_x.size(0),
        num_classes=(paper_y.max() + 1).item(),
    )

    return construct_single_graph_datamodule(original_data, config)



def construct_snap_patents_datamodule(config) -> SingleGraphDataModule:
    """
    Load snap-patents via ``nc_datasets.load_nc_dataset`` (5 label classes,
    downloaded into ./data if needed) and build its single-graph datamodule.

    ``num_classes`` on the built graph is the number of distinct label values.

    TODO: Generalize to any transductive dataset.
    """
    patents = nc_datasets.load_nc_dataset(
        "snap-patents",
        mat_path=None,
        data_dir="./data",
        download=True,
        url=None,
        drive_id="1ldh23TSY1PwXia6dU0MYcpyEgX-w3Hia",
        nclass=5,
    )
    graph, node_labels = patents[0]
    full_graph = Data(
        x=graph["node_feat"],
        edge_index=graph["edge_index"],
        y=node_labels,
        num_nodes=len(node_labels),
        num_classes=len(node_labels.unique()),
    )
    return construct_single_graph_datamodule(full_graph, config)


def construct_arxiv_year_datamodule(config) -> SingleGraphDataModule:
    """
    Load arxiv-year via ``nc_datasets.load_arxiv_year_dataset`` and build its
    single-graph datamodule. ``num_classes`` is the number of distinct label values.
    """
    graph, node_labels = nc_datasets.load_arxiv_year_dataset()[0]
    full_graph = Data(
        x=graph["node_feat"],
        edge_index=graph["edge_index"],
        y=node_labels,
        num_nodes=len(node_labels),
        num_classes=len(node_labels.unique()),
    )
    return construct_single_graph_datamodule(full_graph, config)


def construct_chameleon_datamodule(config) -> SingleGraphDataModule:
    """
    Load the WikipediaNetwork "chameleon" graph from ./data and build its
    single-graph datamodule. ``num_classes`` is ``max(label) + 1``.
    """
    source = WikipediaNetwork(root="./data", name="chameleon")[0]
    highest_label = int(source.y.max().item())
    full_graph = Data(
        x=source.x,
        edge_index=source.edge_index,
        y=source.y,
        num_nodes=source.num_nodes,
        num_classes=highest_label + 1,
    )
    return construct_single_graph_datamodule(full_graph, config)


def construct_pokec_datamodule(config) -> SingleGraphDataModule:
    """
    Load pokec via ``nc_datasets.load_nc_dataset`` (data dir ./data) and build
    its single-graph datamodule. ``num_classes`` is the number of distinct label values.
    """
    pokec = nc_datasets.load_nc_dataset("pokec", data_dir="./data")
    graph, node_labels = pokec[0]
    full_graph = Data(
        x=graph["node_feat"],
        edge_index=graph["edge_index"],
        y=node_labels,
        num_nodes=len(node_labels),
        num_classes=len(node_labels.unique()),
    )
    return construct_single_graph_datamodule(full_graph, config)


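# Support utilities. get_isolated_components and make_loader are not called in this
# module; RandomSubgraphDataset backs SingleGraphDataModule's non-neighbor training path.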
def get_isolated_components(data):
    """
    Return the ids of nodes that appear in no edge of ``data.edge_index``.

    Despite the name, this returns isolated *nodes* (a sorted list of Python
    ints), not connected components. Edge entries outside ``[0, num_nodes)``
    are ignored.

    Runs in O(N + E) with one scatter into a boolean mask, instead of testing
    ``i in edge_index`` for every node, which is O(N * E).
    """
    num_nodes = int(data.num_nodes)
    endpoints = data.edge_index.reshape(-1)
    endpoints = endpoints[(endpoints >= 0) & (endpoints < num_nodes)]
    touched = torch.zeros(num_nodes, dtype=torch.bool, device=endpoints.device)
    touched[endpoints] = True
    return torch.nonzero(~touched, as_tuple=True)[0].tolist()

class RandomSubgraphDataset(Dataset):
    """
    Precomputed, fixed set of node-induced subgraphs of one graph.

    At construction:

    1. The nodes of ``data`` are permuted once with a CPU ``torch.Generator``,
       seeded with ``seed``. When ``seed`` is None the generator keeps its
       default state.
    2. The permutation is cut into ``N // nodes_per_subgraph`` chunks of exactly
       ``nodes_per_subgraph`` nodes.
    3. The induced subgraph of each chunk is built with
       ``_subsample_graph_n_nodes``.

    The last ``N % nodes_per_subgraph`` nodes of the permutation are dropped.
    Because the permutation is fixed, the same nodes are dropped and the same
    subgraphs returned for the lifetime of the dataset.

    Each subgraph carries ``n_id``, the original node ids of its nodes.

    Raises:
        TypeError: If ``data`` is not a PyG ``Data`` object.
        ValueError: If ``nodes_per_subgraph`` is not in ``[1, data.num_nodes]``.
    """

    def __init__(self, data: Data, nodes_per_subgraph: int, seed: Optional[int] = 0):
        # Explicit exceptions rather than ``assert`` so validation still runs under ``python -O``.
        if not isinstance(data, Data):
            raise TypeError("Expect a torch_geometric.data.Data (single graph).")
        if not 0 < nodes_per_subgraph <= data.num_nodes:
            raise ValueError(
                f"nodes_per_subgraph must be in [1, {data.num_nodes}], got {nodes_per_subgraph}."
            )
        self.data = data
        self.N = data.num_nodes
        self.nodes_per_subgraph = int(nodes_per_subgraph)

        # Single randomized permutation (deterministic via seed)
        generator = torch.Generator(device="cpu")
        if seed is not None:
            generator = generator.manual_seed(int(seed))
        permutation = torch.randperm(self.N, generator=generator)

        # Split into equal-sized chunks; drop the remainder
        num_batches = self.N // self.nodes_per_subgraph
        self.node_chunks = list(torch.split(
            permutation[: num_batches * self.nodes_per_subgraph],
            self.nodes_per_subgraph,
        ))

        # Precompute every subgraph once; n_id keeps the original node ids.
        self.subgraphs: List[Data] = []
        for idx in self.node_chunks:
            sub = _subsample_graph_n_nodes(self.data, indices=idx)
            sub.n_id = idx.clone()
            self.subgraphs.append(sub)

    def __len__(self) -> int:
        return len(self.subgraphs)  # == N // nodes_per_subgraph

    def __getitem__(self, i: int) -> Data:
        return self.subgraphs[i]

def make_loader(data: Data, nodes_per_subgraph: int, batch_size: int = 1, seed: int = 0) -> DataLoader:
    """
    Wrap a ``RandomSubgraphDataset`` of ``data`` in an unshuffled PyG ``DataLoader``.

    Args:
        data: Graph to cut into random node-induced subgraphs.
        nodes_per_subgraph: Nodes per subgraph.
        batch_size: Subgraphs collated per batch.
        seed: Seed for the dataset's node permutation.
    """
    subgraphs = RandomSubgraphDataset(data, nodes_per_subgraph, seed)
    # Order is already randomized by the dataset's seeded permutation.
    return DataLoader(subgraphs, batch_size=batch_size, shuffle=False)
