import argparse
import os

import numpy as np

# Enable CUDA debugging. The variable is set here, before torch and
# pytorch_lightning are imported below.
# NOTE(review): this is applied unconditionally at import time; presumably it
# makes CUDA kernel launches synchronous, which usually slows GPU training —
# confirm it should stay enabled outside of debugging sessions.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
import torch
from pytorch_lightning.loggers import WandbLogger
from torch import nn
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.optim import AdamW
from torch_geometric.data import Batch, Data
from torchmetrics import Accuracy

import wandb
from manifold_transformers import models, mt_datasets


def str2bool(value):
    """Convert a command-line token into a boolean.

    Args:
        value: Either an actual ``bool`` (returned unchanged) or a string
            such as ``"true"``/``"no"``; matching is case-insensitive.

    Returns:
        bool: The parsed value.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value

    token = value.lower()
    spellings_by_result = (
        (True, ("true", "1", "yes", "y")),
        (False, ("false", "0", "no", "n")),
    )
    for result, spellings in spellings_by_result:
        if token in spellings:
            return result

    raise argparse.ArgumentTypeError("Boolean value expected.")


def set_seed(seed):
    """Seed every random number generator this script can reach.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and, when
    available, all CUDA devices), then switches cuDNN to deterministic mode
    with benchmarking disabled.

    Args:
        seed (int): Seed value applied to all generators.
    """
    # Local import: the module header does not import ``random``.
    import random

    # Python's own generator was previously left unseeded, so any code relying
    # on ``random`` produced different results from run to run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Covers every visible device, including the current one.
        torch.cuda.manual_seed_all(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def _resolve_model_selections(config):
    """Ensure backbone/positional-encoding selectors are provided."""

    backbone = config.get("backbone")
    posenc = config.get("posenc")

    if not backbone:
        raise ValueError("Please specify --backbone.")
    if posenc is None:
        raise ValueError("Please specify --posenc.")

    config["backbone"] = backbone
    config["posenc"] = posenc
    config["model_type"] = backbone

    return config


class GraphTaskModule(pl.LightningModule):
    """LightningModule that trains and evaluates a graph model built from ``config``.

    The network is created by ``models.create_model``; training uses a
    cross-entropy loss, and evaluation reports loss/accuracy grouped by each
    sample's ``sample_fraction`` attribute.
    """

    def __init__(self, config):
        """Build the model, loss, metric and logging settings.

        Args:
            config: Mapping of run settings. Keys read here: ``dataloader_type``,
                ``dataset``, ``backbone``, ``posenc``, ``num_classes`` and the
                optional ``loss_nodes``, ``debug`` and ``disable_wandb``.

        Raises:
            ValueError: If ``loss_nodes`` is neither ``"all"`` nor ``"seeds"``.
        """
        super().__init__()
        self.config = config

        self.save_hyperparameters(config)

        # To know how to handle training batches.
        self.using_neighbor_loader = config["dataloader_type"] == "neighbor"

        self.loss_nodes = config.get("loss_nodes", "all")
        if self.loss_nodes not in {"all", "seeds"}:
            raise ValueError("--loss_nodes must be one of {'all', 'seeds'}")

        task = self._infer_task_from_dataset(config["dataset"])

        self.model = models.create_model(
            task=task,
            config=config,
            backbone=config["backbone"],
            posenc=config["posenc"],
        )

        self.criterion = CrossEntropyLoss()

        self.accuracy = Accuracy(task="multiclass", num_classes=config["num_classes"])

        # Training metrics are logged per step.
        self.LOG_FLAGS = {
            "sync_dist": True,
            "on_step": True,
            "on_epoch": False,
            "prog_bar": False,
        }
        self.LOG_FLAGS_PROG_BAR = {**self.LOG_FLAGS, "prog_bar": True}

        # Evaluation logs should not advance step; aggregate per epoch
        self.EVAL_LOG_FLAGS = {
            "sync_dist": True,
            "on_step": False,
            "on_epoch": True,
            "prog_bar": False,
        }
        self.EVAL_LOG_FLAGS_PROG_BAR = {**self.EVAL_LOG_FLAGS, "prog_bar": True}

        # Track for gradient logging
        self.log_gradients_every_n_steps = 100
        self.global_step_counter = 0
        self.log_debug_batch = config.get("debug", False)
        self.is_wandb_enabled = not config.get("disable_wandb", False)

    def _infer_task_from_dataset(self, dataset_name):
        """
        Infer the task type based on the dataset name.

        Args:
            dataset_name (str): Name of the dataset

        Returns:
            str: Task type ('node_classification' or 'graph_classification')
        """
        if dataset_name == "reddit-binary":
            return "graph_classification"
        # Every other dataset name, known or not, is treated as node classification.
        return "node_classification"

    def on_after_backward(self):
        """Log per-parameter gradient histograms and norms to WandB.

        Runs on every ``log_gradients_every_n_steps``-th call and does nothing
        when WandB is disabled or the logger exposes no ``experiment.log``.
        """
        # Skip gradient logging entirely when WandB is disabled for performance
        if not self.is_wandb_enabled:
            return

        # Log gradients explicitly every n steps
        self.global_step_counter += 1
        if self.global_step_counter % self.log_gradients_every_n_steps != 0:
            return

        # The logger checks do not depend on the parameter, so do them once.
        logger = self.logger
        if logger is None or not hasattr(logger, "experiment"):
            return
        experiment = logger.experiment
        # Only log if we have a proper WandB logger
        if not hasattr(experiment, "log"):
            return

        for name, param in self.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad.detach()
            experiment.log(
                {
                    f"gradients/histograms/{name}": wandb.Histogram(
                        grad.cpu().numpy()
                    ),
                    f"gradients/norms/{name}": grad.norm().item(),
                },
                step=self.global_step,
                commit=False,
            )

    def forward(self, data):
        """Run the wrapped model on ``data`` and return its output."""
        return self.model(data)

    def _select_loss_scope(self, data, logits):
        """Determine which nodes contribute to the loss for the current batch.

        Returns:
            tuple: ``(logits, labels)``. Both are cut down to the first
            ``data.batch_size`` rows when ``loss_nodes == "seeds"``, the
            neighbor loader is in use and the batch carries a truthy
            ``batch_size`` plus a ``num_sampled_nodes`` attribute; otherwise
            all rows are kept.
        """
        y_all = data.y.reshape(-1)
        logits_all = logits

        if (
            self.loss_nodes == "seeds"
            and self.using_neighbor_loader
            and getattr(data, "batch_size", None)
            and hasattr(data, "num_sampled_nodes")
        ):
            seed_count = int(data.batch_size)
            return logits_all[:seed_count], y_all[:seed_count]

        return logits_all, y_all

    def training_step(self, data, batch_idx):
        """Compute the training loss for one batch and log loss/accuracy.

        Returns:
            The cross-entropy loss over the nodes chosen by ``_select_loss_scope``.
        """
        out = self(data)

        out_target, y_target = self._select_loss_scope(data, out)

        loss = self.criterion(out_target, y_target)
        preds = torch.argmax(out_target, dim=1)
        acc = self.accuracy(preds, y_target)

        self.log("train/loss", loss, **self.LOG_FLAGS_PROG_BAR)
        self.log("train/acc", acc, **self.LOG_FLAGS_PROG_BAR)

        # The prediction table is only produced for the first batch of an epoch
        # and only when that batch has fewer than 10000 rows in ``data.x``.
        debug_mode = (
            self.log_debug_batch
            and self.is_wandb_enabled
            and batch_idx == 0
            and data.x.shape[0] < 10000
            and self.logger is not None
        )
        if debug_mode:
            self._log_debug_overfit_batch(preds, y_target, loss, acc)

        self._log_gradients(batch_idx)

        return loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one validation batch through ``eval_step``."""
        if isinstance(batch, Batch):  # multi-graph batch
            groups = [(int(batch.num_nodes), [batch])]
        else:  # single-graph full validation
            groups = [(int(batch.sample_size), [batch])]
        return self.eval_step(groups, batch_idx, stage="val")

    def eval_step(self, size_groups, batch_idx, stage):
        """
        Evaluate and log transferability metrics.
        Groups all samples (from single- or multi-graph datasets)
        by their sample_fraction, computes mean/std for each group,
        and logs results.

        Args:
            size_groups: Iterable of ``(size_key, samples)`` pairs; only the
                samples are used here.
            batch_idx: Index of the batch (unused).
            stage: ``"val"`` or ``"test"``; per-fraction metrics are logged
                only for ``"test"``.

        Returns:
            The loss of the largest ``sample_fraction`` group, or ``None`` when
            no samples were supplied.
        """
        # Flatten samples from all size groups
        all_samples = [s for _, group in size_groups for s in group]
        if not all_samples:
            # Nothing to evaluate: skip logging instead of failing on max() below.
            return None

        # Group by sample_fraction (default = 1.0)
        grouped = {}
        for sample in all_samples:
            frac = float(getattr(sample, "sample_fraction", 1.0))
            grouped.setdefault(frac, []).append(sample)

        metrics_by_frac = {}
        for frac, samples in grouped.items():
            metrics = self._compute_transfer_metrics(samples)
            metrics_by_frac[frac] = metrics

            if stage == "test":
                loss, loss_std, acc, acc_std, n = metrics
                self.log(
                    f"transferability/{stage}_loss_size{frac:.2f}",
                    loss,
                    batch_size=n,
                    **self.EVAL_LOG_FLAGS,
                )
                self.log(
                    f"transferability/{stage}_acc_size{frac:.2f}",
                    acc,
                    batch_size=n,
                    **self.EVAL_LOG_FLAGS,
                )
                self.log(
                    f"transferability/{stage}_loss_std_size{frac:.2f}",
                    loss_std,
                    batch_size=n,
                    **self.EVAL_LOG_FLAGS,
                )
                self.log(
                    f"transferability/{stage}_acc_std_size{frac:.2f}",
                    acc_std,
                    batch_size=n,
                    **self.EVAL_LOG_FLAGS,
                )

        # Log the largest fraction to progress bar
        largest_frac = max(metrics_by_frac)
        loss, _, acc, _, n = metrics_by_frac[largest_frac]
        # Pass the label count explicitly, as the per-fraction logs above do,
        # so the epoch aggregate does not depend on batch-size inference.
        self.log(f"{stage}/loss", loss, batch_size=n, **self.EVAL_LOG_FLAGS_PROG_BAR)
        self.log(f"{stage}/acc", acc, batch_size=n, **self.EVAL_LOG_FLAGS_PROG_BAR)
        return loss

    def test_step(self, batch, batch_idx):
        """Group the samples of a test batch by ``sample_size`` and evaluate them."""
        grouped: dict[int, list[Data | Batch]] = {}
        for sample in batch:
            size_key = int(sample.sample_size)
            grouped.setdefault(size_key, []).append(sample)
        ordered = sorted(grouped.items())
        return self.eval_step(ordered, batch_idx, stage="test")

    def _compute_transfer_metrics(self, samples):
        """Compute label-count-weighted loss and accuracy statistics.

        Each sample is moved to the module's device and evaluated without
        gradients; its loss and accuracy are weighted by its number of labels.

        Args:
            samples: Non-empty sequence of samples exposing ``y`` and ``to``.

        Returns:
            tuple: ``(loss_mean, loss_std, acc_mean, acc_std, total_count)``.
            The standard deviations are weighted and are zero when only one
            sample is given; ``total_count`` is the summed label count as int.

        Raises:
            ValueError: If ``samples`` is empty.
        """
        if not samples:
            raise ValueError("No samples provided")

        loss_values = []
        acc_values = []
        counts = []

        with torch.no_grad():
            for sample in samples:
                sample = sample.to(self.device)

                # reshape (not view) so non-contiguous label tensors work too,
                # matching _select_loss_scope.
                labels = sample.y.reshape(-1)

                logits = self(sample)
                loss_val = self.criterion(logits, labels).detach()
                preds = torch.argmax(logits, dim=1)
                correct = (preds == labels).sum().detach().float()
                count = float(labels.numel())

                loss_values.append(loss_val)
                acc_values.append(correct / count)
                counts.append(count)

        loss_tensor = torch.stack(loss_values)
        acc_tensor = torch.stack(acc_values)
        weight_tensor = torch.tensor(
            counts,
            dtype=loss_tensor.dtype,
            device=loss_tensor.device,
        )

        # Clamp so the divisions below never see a zero denominator.
        weight_sum = weight_tensor.sum().clamp(min=1.0)
        loss_mean = (loss_tensor * weight_tensor).sum() / weight_sum
        acc_mean = (acc_tensor * weight_tensor).sum() / weight_sum

        if loss_tensor.numel() > 1:
            loss_std = torch.sqrt(
                torch.sum(weight_tensor * (loss_tensor - loss_mean) ** 2)
                / weight_sum
            )
            acc_std = torch.sqrt(
                torch.sum(weight_tensor * (acc_tensor - acc_mean) ** 2)
                / weight_sum
            )
        else:
            loss_std = torch.zeros_like(loss_mean)
            acc_std = torch.zeros_like(acc_mean)

        return loss_mean, loss_std, acc_mean, acc_std, int(weight_sum.item())

    def configure_optimizers(self):
        """Return the AdamW optimizer used for training.

        Gradient clipping is intentionally not configured here: the optimizer
        dictionary accepted by Lightning has no ``gradient_clip_val`` entry, so
        clipping has to be requested through the ``Trainer`` instead.
        """
        return AdamW(
            self.parameters(),
            lr=self.config["learning_rate"],
            betas=(0.9, 0.999),
            eps=1e-8,
            weight_decay=0.01,
        )

    def _log_gradients(self, batch_idx):
        """Log gradient histograms and norms for every n-th training batch.

        NOTE(review): this is called from ``training_step`` before the loss is
        returned, i.e. before the backward pass for the current batch, so the
        ``param.grad`` values read here come from an earlier step (or are
        ``None``). ``on_after_backward`` logs the same keys after backward —
        confirm whether both code paths are wanted.
        """
        # Skip gradient logging entirely when WandB is disabled for performance
        if not self.is_wandb_enabled:
            return

        if batch_idx % self.log_gradients_every_n_steps != 0:
            return
        # The logger type does not change per parameter, so check it once.
        if self.logger is None or not isinstance(self.logger, WandbLogger):
            return

        for name, param in self.named_parameters():
            if not param.requires_grad or param.grad is None:
                continue

            grad = param.grad.detach()
            grad_norm = grad.norm().item()
            experiment = self.logger.experiment
            # Only log if we have a proper WandB logger
            if hasattr(experiment, "log"):
                experiment.log(
                    {
                        f"gradients/histograms/{name}": wandb.Histogram(
                            grad.cpu().numpy()
                        ),
                        f"gradients/norms/{name}": grad_norm,
                    },
                    step=self.global_step,
                    commit=False,
                )
            # Also log scalar values for the gradient norms
            self.log(f"gradients/norms/{name}", grad_norm, **self.LOG_FLAGS)

    def _log_debug_overfit_batch(self, preds, y_target, loss, acc):
        """Log a label-vs-prediction table plus loss/accuracy for a debug batch.

        Does nothing when the logger has no ``experiment.log``.
        """
        if not hasattr(self.logger, "experiment"):
            return
        experiment = self.logger.experiment
        if not hasattr(experiment, "log"):
            return

        preds_cpu = preds.detach().cpu()
        y_target_cpu = y_target.detach().cpu()

        comparison_data = [
            [i, y.item(), p.item()]
            for i, (y, p) in enumerate(zip(y_target_cpu, preds_cpu))
        ]
        comparison_table = wandb.Table(
            columns=["idx", "actual", "predicted"], data=comparison_data
        )
        # All three entries belong to the same step, so send them together.
        experiment.log(
            {
                "debug/predictions_vs_labels": comparison_table,
                "debug/overfit_batch_loss": loss.item(),
                "debug/overfit_batch_acc": acc.item(),
            },
            step=self.global_step,
            commit=False,
        )


def train(config):
    """Launch one ``run`` per fold.

    Each fold receives its own shallow copy of ``config`` with ``fold_idx``
    set, so changes made to top-level keys during a run do not leak into the
    next fold.

    Args:
        config: Run settings; ``k_fold_stratified_split`` gives the fold count.
    """
    num_folds = config["k_fold_stratified_split"]
    for fold_idx in range(num_folds):
        run({**config, "fold_idx": fold_idx})
    
def run(config):
    """Train and test a single fold described by ``config``.

    Seeds the generators, prepares the output directories, builds the
    datamodule, the optional WandB logger, the model and the Lightning
    ``Trainer``, then runs ``fit`` followed by ``test``.

    Args:
        config: Mutable mapping of run settings. ``output_dir`` is replaced by
            its absolute path, and ``num_workers`` is forced to 0 when
            ``disable_parallelism`` is set.
    """
    # Seed first so that everything built below is reproducible.
    seed = config.get("seed")
    if seed is not None:
        set_seed(seed)

    # Must happen BEFORE the datamodule is created, which receives this config.
    if config.get("disable_parallelism", False):
        config["num_workers"] = 0  # Force no workers for debugging

    # Output locations for the trainer and the loggers.
    out_dir = os.path.abspath(config.get("output_dir", "."))
    config["output_dir"] = out_dir
    lightning_dir = os.path.join(out_dir, "lightning_logs")
    wandb_dir = os.path.join(out_dir, "wandb")
    for directory in (lightning_dir, wandb_dir):
        os.makedirs(directory, exist_ok=True)
    os.environ["WANDB_DIR"] = wandb_dir

    datamodule = mt_datasets.create_datamodule_from_config(config)

    def group_label():
        # Shared prefix of the WandB group and of the default run name.
        return (
            f"{config['model_type'].upper()} {config['dataset'].upper()} "
            f"{config['train_downsample_fraction']}"
        )

    run_name = config["run_name"]
    if run_name is None:
        run_name = f"{group_label()}-fold {config['fold_idx'] + 1}"

    wandb_logger = None
    if not config.get("disable_wandb", False):
        wandb_logger = WandbLogger(
            project=config["project"],
            entity=config["entity"],
            tags=[config["experiment_tag"]],
            group=group_label(),
            config=config,
            name=run_name,
            save_dir=wandb_dir,
        )

    model = GraphTaskModule(config)

    # A single device uses the "auto" strategy; otherwise the configured
    # strategy and device specification are passed through unchanged.
    single_device = config.get("disable_parallelism", False) or config["devices"] == "1"
    trainer_kwargs = {
        "strategy": "auto" if single_device else config["strategy"],
        "accelerator": config["accelerator"],
        "devices": 1 if single_device else config["devices"],
    }

    if config["early_stopping_mode"] != "none":
        callbacks = [
            EarlyStopping(
                monitor=config["early_stopping_monitor"],
                patience=config["early_stopping_patience"],
                mode=config["early_stopping_mode"],
                min_delta=config["early_stopping_min_delta"],
            )
        ]
    else:
        callbacks = []

    profiler_name = config["profiler"]
    clip_val = config.get("gradient_clip_val", 0.0)

    # NOTE(review): check_val_every_n_epoch is fixed to 1 here; the
    # ``check_val_every_n_epoch`` entry of ``config`` (exposed by ``main`` as a
    # CLI flag) is never consulted — confirm this is intended.
    trainer = pl.Trainer(
        default_root_dir=lightning_dir,
        min_epochs=config["min_epochs"],
        max_epochs=config["max_epochs"],
        logger=wandb_logger,
        log_every_n_steps=1,  # log on every training step
        val_check_interval=1.0,  # validate once per epoch
        check_val_every_n_epoch=1,
        enable_checkpointing=False,  # checkpointing is disabled for performance
        enable_progress_bar=True,
        enable_model_summary=True,
        profiler=None if profiler_name == "none" else profiler_name,
        gradient_clip_val=clip_val if clip_val > 0.0 else None,  # 0.0 disables clipping
        num_sanity_val_steps=0,
        callbacks=callbacks,
        **trainer_kwargs,
    )

    # Only watch the model in debug mode to reduce overhead.
    if wandb_logger is not None and config.get("debug", False):
        wandb_logger.watch(model.model, log="gradients", log_freq=10)

    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

    # Close the WandB run so the next fold starts a fresh one.
    if wandb_logger is not None:
        wandb_logger.experiment.finish()



# Default of --train_size_pct; also used to detect an explicit override.
_DEFAULT_TRAIN_SIZE_PCT = 1e-4

# Datasets treated as multi-graph: they are sized with --batch_size rather
# than --train_size_pct.
_MULTI_GRAPH_DATASETS = frozenset({"reddit-binary"})


def _build_arg_parser():
    """Create the command-line parser with every training option."""
    parser = argparse.ArgumentParser(
        description="Train node classification models on graph data"
    )

    # ---- Dataset -----------------------------------------------------------
    parser.add_argument(
        "--dataset",
        type=str,
        default="arxiv-year",
        choices=[
            "arxiv",
            "ogbn-mag",
            "reddit-binary",
            "snap-patents",
            "arxiv-year",
            "chameleon",
            "pokec",
        ],
        help="Dataset to use.",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=0,
        help="Number of workers for dataloader. Defaults to 0 because it's best for a small number of graphs/batches.",
    )
    parser.add_argument(
        "--k_fold_stratified_split",
        type=int,
        default=1,
        help="Number of folds for stratified split.",
    )
    # Train/val/test ratios: in terms of #graphs for multi-graph datasets,
    # #nodes for single-graph datasets.
    parser.add_argument(
        "--val_ratio",
        type=float,
        default=0.1,
        help="Ratio of dataset to use for validation.",
    )
    parser.add_argument(
        "--test_ratio",
        type=float,
        default=0.45,
        help="Ratio of dataset to use for testing.",
    )

    # ---- Downsampling ------------------------------------------------------
    parser.add_argument(
        "--max_memory_nodes",
        type=int,
        default=1_000_000,
        help="Max number of nodes for training/testing, determined by memory, (single-graph datasets only).",
    )
    parser.add_argument(
        "--check_val_every_n_epoch",
        type=int,
        default=5,
        help="Check validation every n epochs.",
    )
    # Fractions are in terms of #nodes per graph for multi-graph datasets,
    # #nodes in split for single-graph datasets.
    parser.add_argument(
        "--train_downsample_fraction",
        type=float,
        default=1.0,
        help="Fraction of training nodes to use after train/val/test split.",
    )
    parser.add_argument(
        "--test_downsample_fractions",
        type=float,
        nargs="+",
        default=[
            0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5,
            0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0,
        ],
        help="Fraction of nodes to use for transferability testing",
    )
    parser.add_argument(
        "--test_num_batches_per_size",
        type=int,
        default=5,
        help="For every test fraction, sample this many subgraphs and average their metrics.",
    )

    # ---- Training ----------------------------------------------------------
    parser.add_argument(
        "--learning_rate", type=float, default=0.01, help="Learning rate for optimizer"
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.5,
        help="Dropout rate for regularization. Affects all non-attn dropout parameters.",
    )
    parser.add_argument(
        "--max_epochs", type=int, default=300, help="Maximum number of training epochs"
    )
    parser.add_argument(
        "--min_epochs", type=int, default=1, help="Minimum number of training epochs"
    )
    # Batch sizing: deprecated absolute batch_size, use --train_size_pct instead
    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="[DEPRECATED] Use --train_size_pct instead",
    )
    parser.add_argument(
        "--train_size_pct",
        type=float,
        default=_DEFAULT_TRAIN_SIZE_PCT,
        help="NeighborLoader seed node batch size as a fraction of train nodes",
    )

    # ---- Logging -----------------------------------------------------------
    parser.add_argument(
        "--run_name", type=str, default=None, help="Name for the wandb run"
    )
    parser.add_argument("--entity", type=str, default=None, help="WandB entity name")
    parser.add_argument(
        "--project",
        type=str,
        default="anonymous-project",
        help="WandB project name",
    )
    parser.add_argument(
        "--experiment_tag", type=str, default="dataset", help="WandB experiment tag"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=".",
        help="Base directory to store Lightning logs and WandB files.",
    )

    # ---- Training infrastructure --------------------------------------------
    parser.add_argument(
        "--strategy",
        type=str,
        default="ddp_find_unused_parameters_true",
        help="Training strategy (e.g., 'ddp', 'deepspeed', 'ddp_find_unused_parameters_true')",
    )
    parser.add_argument(
        "--accelerator",
        type=str,
        default="gpu",
        help="Accelerator type (e.g., 'gpu', 'cpu')",
    )
    parser.add_argument(
        "--devices", type=str, default="1", help="Number of devices to use for training"
    )
    parser.add_argument("--debug", action="store_true", help="Run in debug mode")
    parser.add_argument(
        "--profiler",
        type=str,
        default="none",
        help="Enable profiler ('simple','advanced','pytorch','none')",
    )
    parser.add_argument(
        "--seed", type=int, default=42, help="Random seed for reproducibility"
    )
    parser.add_argument(
        "--disable_wandb", action="store_true", help="Disable wandb logging"
    )
    parser.add_argument(
        "--disable_parallelism",
        action="store_true",
        help="Disable parallelism for debugging",
    )
    parser.add_argument(
        "--gradient_clip_val",
        type=float,
        default=5.0,
        help="Gradient clipping value (0.0 means no clipping)",
    )

    # ---- Early stopping ------------------------------------------------------
    parser.add_argument(
        "--early_stopping_patience",
        type=int,
        default=10,
        help="Number of validation checks with no improvement before stopping.",
    )
    parser.add_argument(
        "--early_stopping_monitor",
        type=str,
        default="val/loss",
        help="Validation metric name to monitor for early stopping.",
    )
    parser.add_argument(
        "--early_stopping_mode",
        type=str,
        choices=["min", "max", "none"],
        default="none",
        help="Optimization direction for the monitored metric; use 'none' to disable early stopping.",
    )
    parser.add_argument(
        "--early_stopping_min_delta",
        type=float,
        default=0.0,
        help="Minimum change in the monitored metric to qualify as improvement.",
    )

    # ---- Task / loader options -------------------------------------------------
    # Graph classification specific parameters (inferred automatically)
    parser.add_argument(
        "--pooling",
        type=str,
        default="mean",
        choices=["mean", "max", "sum"],
        help="Pooling method for graph classification (mean, max, sum)",
    )
    parser.add_argument(
        "--dataloader_type",
        type=str,
        default="dataloader",
        choices=["neighbor", "dataloader"],
        help="Type of dataloader to use: 'neighbor' for NeighborLoader, 'dataloader' for DataLoader",
    )
    parser.add_argument(
        "--neighbor_num_neighbors",
        type=int,
        nargs="+",
        default=None,
        help="Neighbor counts per hop for NeighborLoader (ignored when using the standard DataLoader).",
    )
    parser.add_argument(
        "--loss_nodes",
        type=str,
        choices=["all", "seeds"],
        default="all",
        help="When using NeighborLoader, choose whether the loss uses all sampled nodes or only the seed nodes.",
    )
    # Accepts both a bare ``--force_undirected`` and ``--force_undirected <bool>``.
    parser.add_argument(
        "--force_undirected",
        type=str2bool,
        nargs="?",
        const=True,
        default=False,
        help="Convert dataset edge_index to undirected before constructing loaders.",
    )

    # ---- Model selection -------------------------------------------------------
    parser.add_argument(
        "--backbone",
        type=str,
        choices=models.BACKBONE_CHOICES,
        default=None,
        help="Backbone architecture to train (e.g., dense_gt, sparse_gt, gcn).",
    )
    parser.add_argument(
        "--posenc",
        type=str,
        choices=models.POSENC_CHOICES,
        default=None,
        help="Positional encoding to pair with the selected backbone (none, rpearl, data).",
    )

    # ---- Backbone-specific hyperparameters -----------------------------------
    gcn_group = parser.add_argument_group("GCN Backbone")
    gcn_group.add_argument(
        "--gcn_hidden_channels",
        type=int,
        default=256,
        help="Number of hidden channels in GCN",
    )
    gcn_group.add_argument(
        "--gcn_num_layers", type=int, default=3, help="Number of GCN layers"
    )
    gcn_group.add_argument(
        "--gcn_k",
        type=int,
        default=3,
        help="Order K for TAGConv used inside the GCN backbone",
    )

    transformer_group = parser.add_argument_group("Laplacian Transformer Backbone")
    transformer_group.add_argument(
        "--nhead", type=int, default=8, help="Number of attention heads in transformer"
    )
    transformer_group.add_argument(
        "--transformer_d_model",
        type=int,
        default=64,
        help="Dimension of model in transformer",
    )
    transformer_group.add_argument(
        "--dim_feedforward",
        type=int,
        default=512,
        help="Dimension of feedforward network in transformer",
    )
    transformer_group.add_argument(
        "--transformer_num_layers",
        type=int,
        default=1,
        help="Number of transformer layers",
    )

    dense_group = parser.add_argument_group("Dense Graph Transformer Backbone")
    dense_group.add_argument(
        "--dense_d_model",
        type=int,
        default=128,
        help="Dimension of the dense graph transformer backbone.",
    )
    dense_group.add_argument(
        "--dense_nhead",
        type=int,
        default=4,
        help="Number of attention heads in the dense backbone.",
    )
    dense_group.add_argument(
        "--dense_dim_feedforward",
        type=int,
        default=128,
        help="Dimension of the feedforward network inside the dense backbone.",
    )
    dense_group.add_argument(
        "--dense_transformer_num_layers",
        type=int,
        default=3,
        help="Number of transformer layers in the dense backbone.",
    )
    dense_group.add_argument(
        "--dense_attn_dropout",
        type=float,
        default=0.15,
        help="Attention dropout probability used by the dense backbone.",
    )

    exphormer_group = parser.add_argument_group("Exphormer Backbone")
    exphormer_group.add_argument(
        "--exphormer_d_model",
        type=int,
        default=64,
        help="Dimension of model in Exphormer",
    )
    exphormer_group.add_argument(
        "--exphormer_nhead",
        type=int,
        default=8,
        help="Number of attention heads in Exphormer",
    )
    exphormer_group.add_argument(
        "--exphormer_dim_feedforward",
        type=int,
        default=512,
        help="Dimension of feedforward network in Exphormer",
    )
    exphormer_group.add_argument(
        "--exphormer_num_layers",
        type=int,
        default=1,
        help="Number of transformer layers in Exphormer",
    )
    exphormer_group.add_argument(
        "--exphormer_exp_degree",
        type=int,
        default=5,
        help="Degree of expander graph in Exphormer",
    )
    exphormer_group.add_argument(
        "--exphormer_exp_algorithm",
        type=str,
        choices=["Random-d", "Random-d2", "Hamiltonian"],
        default="Random-d",
        help="Expander graph algorithm in Exphormer",
    )

    sparse_group = parser.add_argument_group("Sparse Graph Transformer (KHopGT)")
    sparse_group.add_argument(
        "--sparse_gt_d_model",
        type=int,
        default=64,
        help="Dimension of model in Sparse GT",
    )
    sparse_group.add_argument(
        "--sparse_gt_nhead",
        type=int,
        default=8,
        help="Number of attention heads in Sparse GT",
    )
    sparse_group.add_argument(
        "--sparse_gt_num_hops",
        type=int,
        default=2,
        help="Number of hops for attention window in Sparse GT",
    )
    sparse_group.add_argument(
        "--sparse_gt_num_layers",
        type=int,
        default=1,
        help="Number of layers in Sparse GT",
    )
    sparse_group.add_argument(
        "--sparse_gt_attn_algorithm",
        type=str,
        choices=["naive", "sparse"],
        default="sparse",
        help="Attention implementation to use in Sparse GT (naive or sparse)",
    )
    sparse_group.add_argument(
        "--sparse_gt_attn_dropout",
        type=float,
        default=0.05,
        help="Attention dropout probability in Sparse GT attention",
    )
    sparse_group.add_argument(
        "--sparse_gt_random_graph",
        type=str,
        choices=["Random-d", "Random-d2", "None"],
        default="None",
        help="Random augmentation graph strategy for Sparse GT attention.",
    )
    sparse_group.add_argument(
        "--sparse_gt_random_graph_degree",
        type=int,
        default=5,
        help="Degree used when sampling the random augmentation graph.",
    )

    # ---- Positional-encoding hyperparameters -----------------------------------
    rpearl_group = parser.add_argument_group("RPearl Positional Encoding")
    rpearl_group.add_argument(
        "--rpearl_hidden_channels",
        type=int,
        default=128,
        help="Hidden channels for the RPearl positional encoding module.",
    )
    rpearl_group.add_argument(
        "--rpearl_num_layers",
        type=int,
        default=8,
        help="Number of layers in the RPearl positional encoding module.",
    )
    rpearl_group.add_argument(
        "--rpearl_num_samples",
        type=int,
        default=30,
        help="Number of random samples (M) for RandomGNNPositionalEncodings.",
    )

    data_pe_group = parser.add_argument_group("Data Positional Encoding")
    data_pe_group.add_argument(
        "--data_posenc_hidden_channels",
        type=int,
        default=128,
        help="Hidden channels for the data-driven positional encoding module.",
    )
    data_pe_group.add_argument(
        "--data_posenc_num_layers",
        type=int,
        default=8,
        help="Number of layers for the data-driven positional encoding module.",
    )

    # ---- Baselines ---------------------------------------------------------------
    mlp_group = parser.add_argument_group("MLP Baseline")
    mlp_group.add_argument(
        "--mlp_hidden_dim",
        type=int,
        default=128,
        help="Hidden dimension for MLP backbone",
    )
    mlp_group.add_argument(
        "--mlp_num_layers",
        type=int,
        default=3,
        help="Number of hidden layers in MLP backbone",
    )

    linkx_group = parser.add_argument_group("LINKX Backbone")
    linkx_group.add_argument(
        "--linkx_hidden_dim",
        type=int,
        default=128,
        help="Hidden dimension for LINKX",
    )
    linkx_group.add_argument(
        "--linkx_x_num_layers",
        type=int,
        default=2,
        help="Number of layers in feature MLP for LINKX",
    )
    linkx_group.add_argument(
        "--linkx_a_num_layers",
        type=int,
        default=2,
        help="Number of layers in adjacency MLP for LINKX",
    )

    return parser


def _validate_config(config):
    """Check cross-option constraints and fill the multi-graph batch-size default.

    Raises:
        ValueError: If the batch-sizing flags do not fit the dataset modality,
            or if a count option is not a positive integer.
    """
    is_multi_graph = config.get("dataset") in _MULTI_GRAPH_DATASETS

    if is_multi_graph:
        # Multi-graph: use absolute batch_size; keep legacy default
        if config.get("batch_size") is None:
            config["batch_size"] = 32
        # Guard against misuse of pct on multi-graph if user overrides default
        if config.get("train_size_pct", _DEFAULT_TRAIN_SIZE_PCT) != _DEFAULT_TRAIN_SIZE_PCT:
            raise ValueError(
                "--train_size_pct is only for single-graph datasets. Use --batch_size for multi-graph datasets."
            )
    elif (
        # Single-graph datasets only require pct-based sizing unless using NeighborLoader
        config["dataloader_type"] != "neighbor"
        and config.get("batch_size") is not None
    ):
        raise ValueError(
            "--batch_size is deprecated for single-graph datasets when not using NeighborLoader. Please use --train_size_pct instead."
        )

    if config["test_num_batches_per_size"] <= 0:
        raise ValueError("--test_num_batches_per_size must be a positive integer.")
    # With zero or negative folds the training loop would silently do nothing.
    if config["k_fold_stratified_split"] <= 0:
        raise ValueError("--k_fold_stratified_split must be a positive integer.")


def main():
    """Parse the command line, validate the resulting config and start training.

    Raises:
        ValueError: If required selectors are missing or options conflict.
    """
    args = _build_arg_parser().parse_args()
    config = vars(args)
    _resolve_model_selections(config)
    _validate_config(config)
    train(config)


# Backward compatibility alias. Defined before the entry-point guard so the
# name already exists while main() is running.
# TODO: confirm whether this alias is still needed.
NodeClassification = GraphTaskModule


if __name__ == "__main__":
    main()
