import os.path as osp
import numpy as np

import torch
import torch_geometric.transforms as T
from torch_geometric.transforms import RandomNodeSplit, ToSparseTensor, NormalizeFeatures, Compose
from torch_geometric.datasets import Planetoid, Amazon, Coauthor, Reddit, WikiCS, PPI, WebKB, Actor

from ogb.nodeproppred import PygNodePropPredDataset
import torch_geometric.data.storage
from torch_geometric.data.data import DataTensorAttr

# Allow-list the PyG classes that are pickled inside processed dataset files so that
# ``torch.load(..., weights_only=True)`` can deserialize them. ``add_safe_globals`` registers
# them for the whole process; ``torch.serialization.safe_globals`` is a context manager that
# only takes effect inside a ``with`` block, so it cannot be used for module-level registration.
torch.serialization.add_safe_globals([
    DataTensorAttr,
    torch_geometric.data.data.DataEdgeAttr,
    torch_geometric.data.storage.GlobalStorage,
])


# Train/val/test ratios for the named ``custom-*`` Planetoid splits accepted by get_dataset().
_CUSTOM_PLANETOID_SPLITS = {
    'custom-60-20-20': (0.6, 0.2, 0.2),
    'custom-50-25-25': (0.5, 0.25, 0.25),
    'custom-70-15-15': (0.7, 0.15, 0.15),
    'custom-80-10-10': (0.8, 0.1, 0.1),
    'custom-90-5-5': (0.9, 0.05, 0.05),
}

# Node-mask attributes handled by the split helpers below.
_MASK_KEYS = ('train_mask', 'val_mask', 'test_mask')


def get_dataset(name, root_dir='../data', normalize_features=True, transform=None, split='public'):
    """
    Load a node classification dataset with optional transforms and split handling.

    Supported:
    - Planetoid: Cora, Citeseer, Pubmed
    - Amazon: Photo, Computers
    - Coauthor: CS, Physics
    - WikiCS
    - Reddit
    - WebKB: Texas, Wisconsin, Cornell
    - Actor
    - OGBN: ogbn-arxiv, ogbn-products, etc.
    - PPI (graph-level, see note below)

    Args:
        name: Dataset name (case-insensitive).
        root_dir: Parent directory; data is stored under ``<root_dir>/<name_lower>``. The
            Planetoid 'public', 'fixed' and 'custom-*' splits are loaded by helpers that use
            their own ``<module dir>/../data/<name>`` location instead.
        normalize_features: Apply ``T.NormalizeFeatures()`` on access. For ``ogbn-*``
            datasets a row-sum normalization is applied once at load time instead.
        transform: Extra transform applied after normalization. Not applied to OGB
            datasets, which are returned with ``transform=None``.
        split: Planetoid split: 'public', 'complete', 'fixed', 'random', 'geom-gcn', a
            'custom-*' key of ``_CUSTOM_PLANETOID_SPLITS``, or any other value forwarded to
            ``Planetoid(split=...)``. Ignored for non-Planetoid datasets.

    Returns:
        The dataset object. Single-graph datasets carry train/val/test masks on their
        graph. PPI returns a dict with 'train', 'val' and 'test' datasets.

    Raises:
        ValueError: Unknown dataset name or unsupported ``custom-*`` split.

    Note: PPI is a graph-level classification dataset and returns a dict structure.
    """
    name_lower = name.lower()
    full_path = osp.join(root_dir, name_lower)
    final_transform = _compose_transforms(normalize_features, transform)

    if name_lower in ('cora', 'citeseer', 'pubmed'):
        dataset = _load_planetoid(name_lower, full_path, split, final_transform)

    elif name_lower in ('photo', 'computers'):
        dataset = Amazon(full_path, name=name_lower.capitalize(), transform=final_transform)
        _ensure_random_masks(dataset)

    elif name_lower in ('cs', 'physics'):
        dataset = Coauthor(full_path, name=name_lower.upper(), transform=final_transform)
        _ensure_random_masks(dataset)

    elif name_lower == 'reddit':
        # Reddit already has train/val/test masks.
        dataset = Reddit(full_path, transform=final_transform)

    elif name_lower == 'wikics':
        # WikiCS may ship several splits as 2-D masks; keep the first one of each.
        dataset = WikiCS(full_path, transform=final_transform)
        _select_first_split(dataset)

    elif name_lower in ('texas', 'wisconsin', 'cornell'):
        # Masks are reduced to 1-D [N] (first split if multi-split).
        dataset = WebKB(root=full_path, name=name_lower.capitalize(), transform=final_transform)
        return _select_first_split(dataset)

    elif name_lower == 'actor':
        # Masks are reshaped to a single-split column [N, 1].
        dataset = Actor(root=full_path, transform=final_transform)
        return _select_first_split_as_column(dataset)

    elif name_lower == 'ppi':
        return _load_ppi(name, full_path, final_transform)

    elif name_lower.startswith('ogbn-'):
        return _load_ogbn(name_lower, full_path, normalize_features)

    elif name_lower.startswith('ogbn1-'):
        dataset = _load_ogbn_legacy(name_lower, full_path)

    else:
        raise ValueError(f"Dataset '{name}' is not supported.")

    # Print dataset info for standard datasets
    print_dataset_split_info(dataset, name)
    return dataset


def _compose_transforms(normalize_features, transform):
    """Return NormalizeFeatures (if enabled) followed by ``transform`` as one Compose, or None."""
    steps = []
    if normalize_features:
        steps.append(T.NormalizeFeatures())
    if transform:
        steps.append(transform)
    return T.Compose(steps) if steps else None


def _load_planetoid(name_lower, full_path, split, final_transform):
    """Load Cora/CiteSeer/PubMed with the requested split (see ``get_dataset``)."""
    pyg_name = name_lower.capitalize()

    if split == 'fixed':
        # Seeded random split using get_planetoid_custom_split's default ratios.
        return get_planetoid_custom_split(name_lower, transform=final_transform, random_seed=42)

    if split.startswith('custom'):
        ratios = _CUSTOM_PLANETOID_SPLITS.get(split)
        if ratios is None:
            raise ValueError(f"Unsupported custom split: {split}")
        return get_planetoid_dataset_custom(name_lower, transform=final_transform,
                                            split='complete', split_ratios=ratios)

    if split == 'complete':
        dataset = Planetoid(full_path, pyg_name, transform=final_transform)
        _assign_complete_split(dataset)
        return dataset

    if split == 'public':
        return get_planetoid_dataset(name_lower, transform=final_transform, split='public')

    if split == 'geom-gcn':
        # Geom-GCN uses a different split with multiple masks; keep the first one.
        dataset = Planetoid(full_path, pyg_name, transform=final_transform, split='geom-gcn')
        return _select_first_split(dataset)

    if split == 'random':
        dataset = Planetoid(full_path, pyg_name, transform=final_transform, split='public')
        # 20 training nodes per class, 500 validation, 1000 test. ``split='random'`` is needed:
        # under 'train_rest' RandomNodeSplit ignores ``num_train_per_class``.
        splitter = RandomNodeSplit(split='random', num_train_per_class=20, num_val=500, num_test=1000)
        dataset.data = splitter(dataset.get(0))
        return dataset

    return Planetoid(full_path, pyg_name, transform=final_transform, split=split)


def _assign_complete_split(dataset):
    """'complete' split: last 1000 nodes held out (500 val, then 500 test), the rest train."""
    data = dataset.get(0)  # untransformed copy of the stored graph
    num_nodes = data.num_nodes
    data.train_mask.fill_(False)
    data.train_mask[:num_nodes - 1000] = True
    data.val_mask.fill_(False)
    data.val_mask[num_nodes - 1000:num_nodes - 500] = True
    data.test_mask.fill_(False)
    data.test_mask[num_nodes - 500:] = True
    dataset.data = data  # persist explicitly


def _select_first_split(dataset, keys=_MASK_KEYS):
    """Reduce multi-split masks ([N, S]) to their first split ([N], bool) and persist them.

    ``InMemoryDataset.get``/``__getitem__`` return a shallow copy of the stored graph, so new
    attributes must be written back through ``dataset.data``. ``get(0)`` is used (not
    ``dataset[0]``) so the transform is not baked into storage. 1-D masks are left as-is.
    Returns ``dataset`` for convenience.
    """
    data = dataset.get(0)
    changed = False
    for key in keys:
        mask = getattr(data, key, None)
        if mask is not None and mask.dim() > 1:
            setattr(data, key, mask[:, 0].bool())
            changed = True
    if changed:
        dataset.data = data
    return dataset


def _select_first_split_as_column(dataset):
    """Reshape train/val/test masks to a single-split column ([N, 1], bool) and persist them.

    Multi-split masks ([N, S]) keep their first column; 1-D masks gain a trailing axis.
    Returns ``dataset`` for convenience.
    """
    data = dataset.get(0)
    for key in _MASK_KEYS:
        mask = getattr(data, key)
        column = mask[:, :1] if mask.dim() > 1 else mask.unsqueeze(1)
        setattr(data, key, column.bool())
    dataset.data = data
    return dataset


def _ensure_random_masks(dataset):
    """Add a random ~70/10/20 train/val/test node split if the graph has no train_mask."""
    data = dataset.get(0)
    if getattr(data, 'train_mask', None) is None:
        splitter = RandomNodeSplit(split='train_rest', num_val=0.1, num_test=0.2)
        dataset.data = splitter(data)
    return dataset


def _load_ppi(name, full_path, final_transform):
    """Load PPI as a dict of 'train'/'val'/'test' multi-graph datasets (graph-level task)."""
    print("\n" + "=" * 60)
    print("WARNING: PPI is a graph-level classification dataset!")
    print("It returns a dict with 'train', 'val', 'test' splits.")
    print("Each split contains multiple graphs (not a single graph with node masks).")
    print("This dataset is NOT compatible with node-level GNN training code.")
    print("=" * 60 + "\n")

    dataset = {
        'train': PPI(full_path, split='train', transform=final_transform),
        'val': PPI(full_path, split='val', transform=final_transform),
        'test': PPI(full_path, split='test', transform=final_transform)
    }

    print(f"\n{'=' * 50}")
    print(f"Dataset: {name.upper()} (Graph-level classification)")
    print(f"{'=' * 50}")
    print(f"Train graphs: {len(dataset['train'])}")
    print(f"Val graphs: {len(dataset['val'])}")
    print(f"Test graphs: {len(dataset['test'])}")
    print(f"Number of features: {dataset['train'].num_features}")
    print(f"Number of classes: {dataset['train'].num_classes}")
    print(f"Note: PPI is a multi-graph dataset. Use dataset['train'], dataset['val'], dataset['test']")
    print(f"{'=' * 50}\n")
    return dataset


def _print_ram_usage():
    """Print this process's resident memory in GB; silently skipped if psutil is missing."""
    try:
        import psutil
    except ImportError:
        return
    print(f"RAM used: {psutil.Process().memory_info().rss / 1e9:.2f} GB")


def _load_ogb_node_dataset(name_lower, full_path, transform):
    """Instantiate ``PygNodePropPredDataset`` with ``torch.load`` forced to ``weights_only=False``.

    The patch replaces ``torch.load`` process-wide but is always restored in ``finally``.
    On a corrupted zip download, ``full_path`` is deleted (best effort) before re-raising.
    """
    import shutil
    import zipfile

    # Allow DataEdgeAttr in pickle (OGB edge attributes); best effort only.
    try:
        from torch_geometric.data.data import DataEdgeAttr
        torch.serialization.add_safe_globals([DataEdgeAttr])
    except Exception as e:
        print("Warning: DataEdgeAttr not registered:", e)

    original_torch_load = torch.load

    def patched_torch_load(*args, **kwargs):
        # Full unpickling of downloaded/processed files: only use trusted dataset sources.
        kwargs['weights_only'] = False
        return original_torch_load(*args, **kwargs)

    torch.load = patched_torch_load
    try:
        return PygNodePropPredDataset(name=name_lower, root=full_path, transform=transform)
    except Exception as e:
        corrupted = (isinstance(e, zipfile.BadZipFile)
                     or "BadZipFile" in str(e) or "not a zip file" in str(e))
        if corrupted:
            print("\n" + "=" * 60)
            print("ERROR: Corrupted OGB dataset download!")
            print("=" * 60)
            print(f"Delete directory: {full_path}")
            print("=" * 60 + "\n")
            if osp.exists(full_path):
                try:
                    shutil.rmtree(full_path)
                    print(f"Deleted corrupted directory: {full_path}")
                except Exception as cleanup_error:
                    print(f"Failed to delete directory: {cleanup_error}")
        else:
            print(f"Error loading {name_lower}: {e}")
        raise
    finally:
        # Always restore torch.load
        torch.load = original_torch_load


def _edge_index_from_adj_t(adj_t):
    """Return the nonzero indices of the SparseTensor ``adj_t`` as [2, E], self-loops dropped."""
    # NOTE(review): adj_t is conventionally the transposed adjacency, so row 0 here would be
    # targets and row 1 sources - only matters for directed graphs; confirm.
    coo = adj_t.to_torch_sparse_coo_tensor().coalesce()
    edge_index = coo.indices()
    return edge_index[:, edge_index[0] != edge_index[1]]


def _attach_index_masks(data, split_idx):
    """Set 1-D boolean train/val/test masks on ``data`` from OGB ``get_idx_split()`` indices."""
    for mask_key, split_key in (('train_mask', 'train'), ('val_mask', 'valid'), ('test_mask', 'test')):
        mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        mask[split_idx[split_key]] = True
        setattr(data, mask_key, mask)


def _load_ogbn(name_lower, full_path, normalize_features):
    """Load an ``ogbn-*`` dataset as one graph with edge_index, flat labels and 1-D masks."""
    # Load WITHOUT transforms: e.g. ToSparseTensor would replace edge_index.
    dataset = _load_ogb_node_dataset(name_lower, full_path, transform=None)
    data = dataset[0]
    split_idx = dataset.get_idx_split()

    if getattr(data, 'edge_index', None) is None:
        if getattr(data, 'adj_t', None) is None:
            raise ValueError("No graph structure: missing both edge_index and adj_t")
        print("Converting adj_t to edge_index...")
        data.edge_index = _edge_index_from_adj_t(data.adj_t)
        data.adj_t = None  # free memory
        print(f"edge_index shape: {data.edge_index.shape}")

    # Remove dense edge_attr to save RAM if present
    if getattr(data, 'edge_attr', None) is not None:
        print(f"Removing edge_attr: {data.edge_attr.shape}")
        data.edge_attr = None

    if normalize_features and getattr(data, 'x', None) is not None:
        # Row-sum normalization; all-zero rows are left unchanged.
        # NOTE(review): rows whose entries sum to a small or negative value get large or
        # sign-flipped features - confirm the features are non-negative.
        row_sum = data.x.sum(dim=1, keepdim=True)
        row_sum[row_sum == 0] = 1
        data.x = data.x / row_sum

    data.y = data.y.reshape(-1).contiguous()
    _attach_index_masks(data, split_idx)

    # Persist the processed graph and disable transforms so dataset[0] returns it as-is.
    dataset.data = data
    dataset.transform = None

    num_edges = data.edge_index.shape[1] // 2  # assumes each undirected edge is stored twice
    print(f"\nDataset: {name_lower}")
    print(f"Nodes: {data.num_nodes:,}, Edges: {num_edges:,}")
    print(f"Features: {data.x.shape[1] if data.x is not None else 0}")
    print(f"Classes: {data.y.max().item() + 1}")
    print(f"Train: {data.train_mask.sum().item():,}, Val: {data.val_mask.sum().item():,}, "
          f"Test: {data.test_mask.sum().item():,}")
    _print_ram_usage()
    return dataset


def _load_ogbn_legacy(name_lower, full_path):
    """Legacy ``ogbn1-*`` loader: feature-normalized graph with edge_index and masks persisted.

    NOTE(review): stock OGB only registers ``ogbn-*`` names, so PygNodePropPredDataset likely
    rejects ``ogbn1-*`` unless a customized OGB is installed - confirm this path is still used.
    """
    # Only NormalizeFeatures here: ToSparseTensor may remove/replace edge_index.
    dataset = _load_ogb_node_dataset(name_lower, full_path,
                                     transform=T.Compose([T.NormalizeFeatures()]))
    data = dataset[0]  # normalized copy; persisted below with the transform disabled
    split_idx = dataset.get_idx_split()

    if getattr(data, 'edge_attr', None) is not None:
        print(f"Removing edge_attr: {data.edge_attr.shape}")
        data.edge_attr = None

    if getattr(data, 'edge_index', None) is None:
        if getattr(data, 'adj_t', None) is None:
            raise ValueError("No graph structure: missing both edge_index and adj_t")
        data.edge_index = _edge_index_from_adj_t(data.adj_t)
        print(f"[FIX] Built edge_index from adj_t: {data.edge_index.shape}")
    else:
        print(f"[INFO] edge_index present: {data.edge_index.shape}")

    data.y = data.y.reshape(-1).contiguous()
    _attach_index_masks(data, split_idx)

    dataset.data = data
    dataset.transform = None  # prevents re-normalizing on every __getitem__

    num_edges = data.edge_index.shape[1] // 2 if data.edge_index.numel() else 0
    print(f"Nodes: {data.num_nodes}, Edges: {num_edges:,}")
    print(f"Adj type: {type(getattr(data, 'adj_t', None))}")
    _print_ram_usage()
    return dataset


def get_planetoid_dataset_custom(name, normalize_features=False, transform=None, split="public",
                                 split_ratios=(0.8, 0.1, 0.1)):
    """
    Load a Planetoid dataset from ``<module dir>/../data/<name>``, optionally re-splitting it.

    With ``split='complete'`` the masks are overwritten using an unseeded ``np.random``
    permutation divided by ``split_ratios`` = (train, val, test). Train and val sizes are
    floored to ints and the test set takes every remaining node. Any other ``split`` is
    forwarded to ``Planetoid``.

    Args:
        name: Lower-case Planetoid name ('cora', 'citeseer', 'pubmed').
        normalize_features: Apply ``T.NormalizeFeatures()`` on access.
        transform: Extra transform applied on access, after normalization if enabled.
        split: 'complete' for a ratio-based random split, else a Planetoid split name.
        split_ratios: (train, val, test) fractions; only used when ``split='complete'``.

    Returns:
        The ``Planetoid`` dataset.

    Raises:
        ValueError: If ``split='complete'`` and the ratios do not sum to 1.
    """
    root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)

    if split != 'complete':
        dataset = Planetoid(root, name.capitalize(), split=split)
    else:
        dataset = Planetoid(root, name.capitalize())

        train_ratio, val_ratio, test_ratio = split_ratios
        if not np.isclose(train_ratio + val_ratio + test_ratio, 1.0):
            raise ValueError(f"Split ratios must sum to 1.0, got {sum(split_ratios)}")

        # Mask tensors are edited in place (the code relies on dataset[0] sharing them
        # with the stored graph).
        graph = dataset[0]
        n = graph.num_nodes
        order = np.random.permutation(n)
        n_train = int(train_ratio * n)
        cut = n_train + int(val_ratio * n)

        for mask, chosen in ((graph.train_mask, order[:n_train]),
                             (graph.val_mask, order[n_train:cut]),
                             (graph.test_mask, order[cut:])):
            mask.fill_(False)
            mask[chosen] = True

    if normalize_features:
        dataset.transform = (T.Compose([T.NormalizeFeatures(), transform])
                             if transform is not None else T.NormalizeFeatures())
    elif transform is not None:
        dataset.transform = transform
    return dataset


def get_planetoid_dataset(name, normalize_features=False, transform=None, split="public"):
    """
    Load a Planetoid dataset from ``<module dir>/../data/<name>``.

    ``split='complete'`` rewrites the masks so that the last 1000 nodes are held out
    (500 validation followed by 500 test) and all earlier nodes are used for training.
    Any other ``split`` is forwarded to ``Planetoid``.

    Args:
        name: Lower-case Planetoid name ('cora', 'citeseer', 'pubmed').
        normalize_features: Apply ``T.NormalizeFeatures()`` on access.
        transform: Extra transform applied on access, after normalization if enabled.
        split: 'complete' or a Planetoid split name.

    Returns:
        The ``Planetoid`` dataset.
    """
    root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', name)

    if split != 'complete':
        dataset = Planetoid(root, name.capitalize(), split=split)
    else:
        dataset = Planetoid(root, name.capitalize())
        # Mask tensors are edited in place (the code relies on dataset[0] sharing them
        # with the stored graph).
        graph = dataset[0]
        n = graph.num_nodes
        graph.train_mask.fill_(False)
        graph.train_mask[:n - 1000] = True
        graph.val_mask.fill_(False)
        graph.val_mask[n - 1000:n - 500] = True
        graph.test_mask.fill_(False)
        graph.test_mask[n - 500:] = True

    steps = []
    if normalize_features:
        steps.append(T.NormalizeFeatures())
    if transform is not None:
        steps.append(transform)
    if len(steps) == 2:
        dataset.transform = T.Compose(steps)
    elif steps:
        dataset.transform = steps[0]
    return dataset


def get_planetoid_custom_split(dataset_name, normalize_features=False, transform=None,
                               train_ratio=0.48, val_ratio=0.32, test_ratio=0.20,
                               random_seed=42):
    """
    Get Planetoid dataset (Cora, CiteSeer, or PubMed) with a seeded random split.

    Data is stored under ``<module dir>/../data/<dataset_name>``.

    Args:
        dataset_name: 'cora', 'citeseer' or 'pubmed' (case-insensitive).
        normalize_features: Apply ``T.NormalizeFeatures()`` on access.
        transform: Extra transform applied on access, after normalization if enabled.
        train_ratio: Fraction of nodes used for training (count floored to an int).
        val_ratio: Fraction of nodes used for validation (count floored to an int).
        test_ratio: Accepted for API symmetry only. The test set is always the remaining
            ``num_nodes - train - val`` nodes.
        random_seed: Seed for the permutation. NOTE: this seeds the *global* NumPy and
            torch RNGs as a side effect.

    Returns:
        The ``Planetoid`` dataset with its masks overwritten in place.

    Raises:
        ValueError: Unknown dataset name, or ``train_ratio + val_ratio`` exceeds 1.
    """
    valid_datasets = ['cora', 'citeseer', 'pubmed']
    if dataset_name.lower() not in valid_datasets:
        raise ValueError(f"Dataset must be one of {valid_datasets}, got {dataset_name}")

    # Otherwise the validation slice is silently truncated and the test set is empty.
    held_out = train_ratio + val_ratio
    if held_out > 1.0 and not np.isclose(held_out, 1.0):
        raise ValueError(f"train_ratio + val_ratio must not exceed 1.0, got {held_out}")

    path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset_name)
    dataset = Planetoid(path, dataset_name.capitalize())

    # Set random seed for reproducibility (global RNG state is changed).
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)

    data = dataset[0]
    num_nodes = data.num_nodes

    # Exact split sizes; the test split takes the remainder.
    train_size = int(num_nodes * train_ratio)
    val_size = int(num_nodes * val_ratio)

    indices = torch.randperm(num_nodes)
    train_indices = indices[:train_size]
    val_indices = indices[train_size:train_size + val_size]
    test_indices = indices[train_size + val_size:]

    # Masks are edited in place (the code relies on dataset[0] sharing them with the
    # stored graph).
    data.train_mask.fill_(False)
    data.val_mask.fill_(False)
    data.test_mask.fill_(False)

    data.train_mask[train_indices] = True
    data.val_mask[val_indices] = True
    data.test_mask[test_indices] = True

    # Apply transforms
    if transform is not None and normalize_features:
        dataset.transform = T.Compose([T.NormalizeFeatures(), transform])
    elif normalize_features:
        dataset.transform = T.NormalizeFeatures()
    elif transform is not None:
        dataset.transform = transform

    return dataset


def get_planetoid_stratified_split(dataset_name, normalize_features=False, transform=None,
                                   train_ratio=0.48, val_ratio=0.32, test_ratio=0.20,
                                   random_seed=42):
    """
    Get a Planetoid dataset whose masks follow a class-stratified random split.

    Two scikit-learn ``train_test_split`` calls, both stratified on the labels and seeded
    with ``random_seed``, are used. The first carves out ``train_ratio`` of the nodes. The
    second divides the remainder between validation and test in the proportion
    ``val_ratio : test_ratio``. Data is stored under ``<module dir>/../data/<dataset_name>``.

    Args:
        dataset_name: 'cora', 'citeseer' or 'pubmed' (case-insensitive).
        normalize_features: Apply ``T.NormalizeFeatures()`` on access.
        transform: Extra transform applied on access, after normalization if enabled.
        train_ratio, val_ratio, test_ratio: Split fractions.
        random_seed: ``random_state`` passed to both ``train_test_split`` calls.

    Returns:
        The ``Planetoid`` dataset with its masks overwritten in place.

    Raises:
        ValueError: For names other than cora/citeseer/pubmed.
    """
    from sklearn.model_selection import train_test_split

    supported = ['cora', 'citeseer', 'pubmed']
    if dataset_name.lower() not in supported:
        raise ValueError(f"Dataset must be one of {supported}, got {dataset_name}")

    root = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', dataset_name)
    dataset = Planetoid(root, dataset_name.capitalize())
    graph = dataset[0]

    all_nodes = torch.arange(graph.num_nodes)
    y = graph.y.numpy()

    # Stage 1: train vs. held-out nodes.
    train_idx, rest_idx, _, rest_y = train_test_split(
        all_nodes, y,
        test_size=1 - train_ratio,
        random_state=random_seed,
        stratify=y,
    )

    # Stage 2: split the held-out nodes into validation and test.
    val_share = val_ratio / (val_ratio + test_ratio)
    val_idx, test_idx = train_test_split(
        rest_idx,
        test_size=1 - val_share,
        random_state=random_seed,
        stratify=rest_y,
    )

    # Masks are edited in place (the code relies on dataset[0] sharing them with the
    # stored graph).
    for mask in (graph.train_mask, graph.val_mask, graph.test_mask):
        mask.fill_(False)
    graph.train_mask[train_idx] = True
    graph.val_mask[val_idx] = True
    graph.test_mask[test_idx] = True

    steps = []
    if normalize_features:
        steps.append(T.NormalizeFeatures())
    if transform is not None:
        steps.append(transform)
    if len(steps) == 2:
        dataset.transform = T.Compose(steps)
    elif steps:
        dataset.transform = steps[0]

    return dataset


def print_class_distribution(dataset, dataset_name="Dataset"):
    """
    Print a per-class table of how many nodes fall in each split.

    ``dataset[0]`` must expose ``y`` (integer class labels) and boolean ``train_mask``,
    ``val_mask`` and ``test_mask`` tensors aligned with ``y``. Classes are enumerated as
    ``0 .. y.max()``, so a class with no nodes still gets a row of zeros.
    """
    graph = dataset[0]
    y = graph.y
    columns = ("Class", "Train", "Val", "Test", "Total")

    print(f"\nClass Distribution for {dataset_name}:")
    print(" ".join(f"{col:<8}" for col in columns))
    print("-" * 48)

    for cls in range(y.max().item() + 1):
        in_class = y == cls
        row = [cls]
        row.extend(
            (split_mask & in_class).sum().item()
            for split_mask in (graph.train_mask, graph.val_mask, graph.test_mask)
        )
        row.append(in_class.sum().item())
        print(" ".join(f"{value:<8}" for value in row))


def print_dataset_split_info(dataset, dataset_name="Dataset"):
    """Print detailed information about dataset splits including counts and percentages.

    ``dataset[0]`` must expose ``num_nodes`` and ``train_mask``/``val_mask``/``test_mask``.
    Counts are ``mask.sum()``, so a multi-split (2-D) mask is summed over all of its splits.
    Percentages are reported as 0.00% for an empty graph instead of dividing by zero.
    """
    data = dataset[0]
    total_nodes = data.num_nodes

    print(f"\n{'=' * 50}")
    print(f"Dataset: {dataset_name}")
    print(f"{'=' * 50}")
    print(f"Total nodes: {total_nodes:,}")
    print(f"Number of features: {data.num_features if hasattr(data, 'num_features') else data.x.shape[1]}")

    if hasattr(dataset, 'num_classes'):
        print(f"Number of classes: {dataset.num_classes}")
    elif getattr(data, 'y', None) is not None and data.y.numel() > 0:
        # y.max() would raise on an empty label tensor.
        print(f"Number of classes: {data.y.max().item() + 1}")

    print("\nSplit Information:")
    print(f"{'Split':<12} {'Count':<8} {'Percentage':<12}")
    print(f"{'-' * 32}")
    for label, mask in (("Training", data.train_mask),
                        ("Validation", data.val_mask),
                        ("Test", data.test_mask)):
        count = mask.sum().item()
        pct = (count / total_nodes) * 100 if total_nodes else 0.0
        # Format "12.34%" first so the sign stays attached to the number when padded.
        pct_text = f"{pct:.2f}%"
        print(f"{label:<12} {count:<8} {pct_text:<12}")
    print(f"{'=' * 50}\n")


if __name__ == '__main__':
    # Smoke test: load every single-graph dataset with the 'public' split (get_dataset's
    # default) and print its split summary. Datasets are downloaded on first use.
    smoke_test_names = (
        'Cora', 'CiteSeer', 'PubMed',  # Planetoid
        'Photo', 'Computers',          # Amazon
        'CS', 'Physics',               # Coauthor
        'WikiCS',
        'Reddit',
    )
    banner = '#' * 60
    for ds_name in smoke_test_names:
        print(f"\n{banner}")
        print(f"Testing {ds_name}")
        print(banner)
        dataset = get_dataset(ds_name, split='public')