import logging
import time

import numpy as np
import torch
from torch_geometric.graphgym.checkpoint import (
    clean_ckpt,
    load_ckpt,
    save_ckpt,
)
from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.loss import compute_loss
from torch_geometric.graphgym.register import register_train
from torch_geometric.graphgym.utils.epoch import is_eval_epoch

from graphgps.loss.subtoken_prediction_loss import subtoken_cross_entropy


def percentile(t, q):
    """Return the ``q``-th percentile of tensor ``t`` as a Python scalar.

    The tensor is flattened and the element at the nearest rank is picked;
    no interpolation between neighbouring elements takes place.

    Args:
        t: Tensor holding the values to rank.
        q: Percentile expressed on a 0-100 scale.

    Returns:
        The selected element, converted with ``.item()``.
    """
    flat = t.view(-1)
    # kthvalue is 1-based: q=0 selects rank 1 (smallest element) and
    # q=100 selects rank numel (largest element).
    rank = 1 + round(0.01 * float(q) * (t.numel() - 1))
    kth_value, _ = flat.kthvalue(rank)
    return kth_value.item()


def calculate_sparsities(sparsity, epoch, max_epoch_half, schedule_type):
    """Schedule every target sparsity of the multi-mode (mm) setting.

    Args:
        sparsity: Sequence of target sparsities.
        epoch: Current position in the schedule.
        max_epoch_half: Epoch at which the schedule reaches the targets.
        schedule_type: One of ``"linear_ascending"``, ``"curve"``,
            ``"linear_descent"`` or ``"fixed"``.

    Returns:
        A new list of scheduled sparsities, or ``sparsity`` itself for
        ``"fixed"`` and for any unrecognised schedule name.
    """
    if schedule_type == "linear_descent":
        # Linear interpolation from 1 (epoch 0) down to each target
        # (epoch == max_epoch_half).
        return [
            1 - ((1 - target) * (epoch / max_epoch_half))
            for target in sparsity
        ]

    if schedule_type == "linear_ascending":
        # Linear interpolation from 0 (epoch 0) up to each target.
        return [target * (epoch / max_epoch_half) for target in sparsity]

    if schedule_type == "curve":
        remaining = (max_epoch_half - epoch) / max_epoch_half
        # Goes from 0 (epoch 0) to 1 (epoch == max_epoch_half).
        scale = np.sqrt(1 - remaining**2)
        scheduled = []
        for idx, target in enumerate(sparsity):
            if idx == 0:
                scheduled.append(target * scale)
                continue
            # Later entries start from an offset derived from the first
            # target instead of starting from zero.
            offset = (target - sparsity[0]) / (1 - sparsity[0])
            scheduled.append(offset + (target - offset) * scale)
        return scheduled

    # "fixed" and unknown schedule names leave the targets untouched.
    return sparsity


def calculate_single_sparsity(sparsity, epoch, max_epoch_half, schedule_type):
    """Schedule the one target sparsity of the single-mode (sm) setting.

    Args:
        sparsity: Target sparsity value.
        epoch: Current position in the schedule.
        max_epoch_half: Epoch at which the schedule reaches the target.
        schedule_type: One of ``"linear_ascending"``, ``"curve"``,
            ``"linear_descent"`` or ``"fixed"``.

    Returns:
        The scheduled sparsity; ``sparsity`` unchanged for ``"fixed"`` and
        for any unrecognised schedule name.
    """
    if schedule_type == "curve":
        remaining = (max_epoch_half - epoch) / max_epoch_half
        # Goes from 0 (epoch 0) to 1 (epoch == max_epoch_half).
        scale = np.sqrt(1 - remaining**2)
        return sparsity * scale

    if schedule_type == "linear_descent":
        # Linear interpolation from 1 down to the target.
        return 1 - ((1 - sparsity) * (epoch / max_epoch_half))

    if schedule_type == "linear_ascending":
        # Linear interpolation from 0 up to the target.
        return sparsity * (epoch / max_epoch_half)

    return sparsity


def _collect_score_values(model, keyword):
    """Pool the score parameters whose name contains ``keyword``.

    Parameters are selected when they carry a truthy ``is_score`` attribute.
    Absolute values are returned when ``cfg.slt.enable_abs_pruning`` or
    ``cfg.slt.sign_mask`` is set, raw values otherwise.
    """
    scores = torch.cat(
        [
            p.detach().flatten()
            for name, p in model.named_parameters()
            if getattr(p, "is_score", False) and keyword in name
        ]
    )
    if cfg.slt.enable_abs_pruning or cfg.slt.sign_mask:
        return scores.abs()
    return scores


def get_threshold(model, sparsity, epoch=None, keyword=""):
    """Compute the pruning threshold(s) for the selected score parameters.

    Args:
        model: Model whose score parameters are ranked.
        sparsity: List of target sparsities (fractions). Multi-mode uses
            every entry, single-mode only the first one.
        epoch: Zero-based epoch index. While the model is in training mode
            and ``epoch + 1 <= cfg.optim.max_epoch // 2`` the sparsity is
            scheduled according to ``cfg.slt.sparse_scheduling``; afterwards
            the targets are used directly. ``None`` disables scheduling
            (it previously raised ``TypeError`` on ``epoch += 1``).
        keyword: Substring a parameter name must contain to be included;
            the empty string matches every score parameter.

    Returns:
        A list of thresholds when ``cfg.slt.mm`` is set, a single threshold
        when only ``cfg.slt.sm`` is set, and ``None`` when neither is set.
    """
    max_epoch_half = cfg.optim.max_epoch // 2
    if epoch is None:
        use_schedule = False
    else:
        # The schedule is evaluated on a 1-based epoch count.
        epoch += 1
        use_schedule = model.training and epoch <= max_epoch_half

    if cfg.slt.mm:
        if use_schedule:
            sparsities = calculate_sparsities(
                sparsity, epoch, max_epoch_half, cfg.slt.sparse_scheduling
            )
        else:
            sparsities = sparsity
        if not sparsities:
            # Nothing to rank; avoid touching the parameters at all.
            return []
        # The pooled scores do not depend on the sparsity value, so build
        # them once instead of once per entry.
        scores = _collect_score_values(model, keyword)
        return [percentile(scores, value * 100) for value in sparsities]

    if cfg.slt.sm:
        if use_schedule:
            sparsity_value = calculate_single_sparsity(
                sparsity[0], epoch, max_epoch_half, cfg.slt.sparse_scheduling
            )
        else:
            sparsity_value = sparsity[0]
        scores = _collect_score_values(model, keyword)
        return percentile(scores, sparsity_value * 100)

    return None


def get_model_sparsity(model, threshold):
    """Measure what percentage of score values fall below ``threshold``.

    Every parameter whose name contains ``"score"`` is flattened and pooled.
    Absolute values are compared when ``cfg.slt.enable_abs_pruning`` is set,
    raw values otherwise.

    Args:
        model: Model exposing ``named_parameters()``.
        threshold: Cut-off value; when a list is given only its first entry
            is used.

    Returns:
        Percentage (0-100) of pooled values strictly below the cut-off.
    """
    pooled = torch.cat(
        [
            p.detach().flatten()
            for name, p in model.named_parameters()
            if "score" in name
        ]
    )
    if cfg.slt.enable_abs_pruning:
        pooled = pooled.abs()

    cutoff = threshold[0] if isinstance(threshold, list) else threshold

    below = (pooled < cutoff).sum().item()
    return (below / pooled.numel()) * 100


def get_kinds_of_unique(model):
    """Count the distinct values stored in the model's frozen parameters.

    Only parameters with ``requires_grad`` equal to ``False`` are considered.

    Args:
        model: Model exposing ``named_parameters()``.

    Returns:
        Number of distinct scalar values across all frozen parameters.
    """
    frozen = [
        p.detach().flatten()
        for _, p in model.named_parameters()
        if p.requires_grad is False
    ]
    values = torch.cat(frozen).tolist()
    return len(set(values))


def train_epoch(
    logger,
    loader,
    model,
    optimizer,
    scheduler,
    batch_accumulation,
    cur_epoch,
):
    """Run one training pass over ``loader``.

    Args:
        logger: Logger whose ``update_stats`` is called once per batch.
        loader: Iterable of training batches.
        model: Model to train; switched to training mode here.
        optimizer: Optimizer stepped every ``batch_accumulation`` batches
            and on the last batch.
        scheduler: LR scheduler; only read for the current learning rate.
        batch_accumulation: Number of batches whose gradients are
            accumulated before an optimizer step.
        cur_epoch: Zero-based index of the current epoch.
    """
    model.train()

    optimizer.zero_grad()
    device = torch.device(cfg.accelerator)
    slt_enabled = cfg.slt.sm is True or cfg.slt.mm is True
    time_start = time.time()
    for batch_idx, batch in enumerate(loader):
        batch.split = "train"
        batch.to(device)

        if slt_enabled:
            # Thresholds are recomputed for every batch; the score
            # parameters may have been updated by the previous step.
            if cfg.slt.pruning == "blockwise":
                # NOTE(review): the block-wise branch passes
                # cfg.optim.max_epoch rather than cur_epoch, so
                # get_threshold never applies the sparsity schedule here,
                # while the global branch does. Preserved as-is; confirm
                # the asymmetry is intended.
                mpnn_th, msa_th, ffn_th, encoder_th, pred_th = (
                    get_thresholds_blockwise(model, cfg.optim.max_epoch)
                )
                global_th = None
            elif cfg.slt.pruning == "global":
                global_th, mpnn_th, msa_th, ffn_th, encoder_th, pred_th = (
                    get_thresholds_global(model, cur_epoch)
                )
            else:
                # Unknown pruning mode: pass no thresholds instead of
                # failing on unbound locals (matches model_forward).
                mpnn_th = msa_th = ffn_th = None
                encoder_th = pred_th = global_th = None

            pred, true = model(
                batch,
                cur_epoch,
                mpnn_th,
                msa_th,
                ffn_th,
                encoder_th,
                pred_th,
                global_th,
            )
        else:
            pred, true = model(batch)

        loss, _true, _pred = compute_loss_and_prepare_outputs(pred, true)

        loss.backward()

        # Step after every `batch_accumulation` batches and on the final
        # batch so that trailing gradients are not dropped.
        if ((batch_idx + 1) % batch_accumulation == 0) or (
            batch_idx + 1 == len(loader)
        ):
            if cfg.optim.clip_grad_norm:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), cfg.optim.clip_grad_norm_value
                )
            optimizer.step()
            optimizer.zero_grad()

        logger.update_stats(
            true=_true,
            pred=_pred,
            loss=loss.detach().cpu().item(),
            lr=scheduler.get_last_lr()[0],
            time_used=time.time() - time_start,
            params=cfg.params,
            dataset_name=cfg.dataset.name,
        )
        time_start = time.time()


def model_forward(batch, model, cur_epoch):
    """Run ``model`` on ``batch``, supplying pruning thresholds if needed.

    When neither ``cfg.slt.sm`` nor ``cfg.slt.mm`` is ``True`` the model is
    called with the batch only. Otherwise thresholds are computed according
    to ``cfg.slt.pruning`` and passed along with ``cur_epoch``.

    Args:
        batch: Input batch forwarded to the model.
        model: Model to call.
        cur_epoch: Zero-based epoch index used for threshold scheduling.

    Returns:
        Whatever the model call returns.
    """
    if not (cfg.slt.sm is True or cfg.slt.mm is True):
        # No thresholds are consumed in this mode, so do not compute any.
        return model(batch)

    if cfg.slt.pruning == "blockwise":
        mpnn_th, msa_th, ffn_th, encoder_th, pred_th = (
            get_thresholds_blockwise(model, cur_epoch)
        )
        # The block-wise helper returns five values; the global threshold
        # must be set explicitly or the call below hits an unbound local.
        global_th = None
    elif cfg.slt.pruning == "global":
        global_th, mpnn_th, msa_th, ffn_th, encoder_th, pred_th = (
            get_thresholds_global(model, cur_epoch)
        )
    else:
        mpnn_th = msa_th = ffn_th = encoder_th = pred_th = global_th = None

    return model(
        batch,
        cur_epoch,
        mpnn_th,
        msa_th,
        ffn_th,
        encoder_th,
        pred_th,
        global_th,
    )


def compute_loss_and_prepare_outputs(pred, true):
    """Compute the loss and return targets/predictions ready for logging.

    For the ``ogbg-code2`` dataset the sub-token cross-entropy is used and
    targets/scores are returned as-is. For every other dataset the generic
    loss is used and targets/scores are detached and moved to the CPU.

    Args:
        pred: Model predictions.
        true: Ground-truth targets.

    Returns:
        Tuple ``(loss, true_for_logging, pred_for_logging)``.
    """
    if cfg.dataset.name != "ogbg-code2":
        loss, pred_score = compute_loss(pred, true)
        return (
            loss,
            true.detach().to("cpu", non_blocking=True),
            pred_score.detach().to("cpu", non_blocking=True),
        )

    loss, pred_score = subtoken_cross_entropy(pred, true)
    return loss, true, pred_score


def get_thresholds_blockwise(model, cur_epoch):
    """Compute one threshold per enabled block.

    A block's threshold is computed only when its ``cfg.slt`` flag
    (``mpnn``, ``msa``, ``ffn``, ``encoder``, ``pred``) is truthy; disabled
    blocks get ``None``.

    Args:
        model: Model whose score parameters are ranked.
        cur_epoch: Epoch index forwarded to ``get_threshold``.

    Returns:
        Tuple ``(mpnn_th, msa_th, ffn_th, encoder_th, pred_th)``.
    """

    def block_threshold(keyword):
        # Threshold over the score parameters whose name contains keyword.
        return get_threshold(
            model, cfg.slt.linear_sparsity, cur_epoch, keyword
        )

    mpnn_th = msa_th = ffn_th = encoder_th = pred_th = None
    if cfg.slt.mpnn:
        mpnn_th = block_threshold("local_model")
    if cfg.slt.msa:
        msa_th = block_threshold("self_attn")
    if cfg.slt.ffn:
        ffn_th = block_threshold("ff_linear")
    if cfg.slt.encoder:
        encoder_th = block_threshold("encoder")
    if cfg.slt.pred:
        pred_th = block_threshold("post_mp")
    return mpnn_th, msa_th, ffn_th, encoder_th, pred_th


def get_thresholds_global(model, cur_epoch):
    """Compute a single threshold over all score parameters.

    Args:
        model: Model whose score parameters are ranked.
        cur_epoch: Epoch index forwarded to ``get_threshold``.

    Returns:
        Six-tuple ``(global_th, None, None, None, None, None)``; the five
        ``None`` entries stand in for the per-block thresholds.
    """
    no_block_thresholds = (None,) * 5
    return (
        get_threshold(model, cfg.slt.linear_sparsity, cur_epoch, ""),
        *no_block_thresholds,
    )


@torch.no_grad()
def eval_epoch(
    logger,
    loader,
    model,
    cur_epoch,
    mpnn_th,
    msa_th,
    ffn_th,
    encoder_th,
    pred_th,
    global_th,
    split="val",
    ood_classes=None,
):
    """Evaluate ``model`` on every batch of ``loader`` and log the stats.

    Args:
        logger: Logger whose ``update_stats`` is called once per batch.
        loader: Iterable of evaluation batches.
        model: Model to evaluate; switched to eval mode here.
        cur_epoch: Unused in this function; kept for interface
            compatibility. The model is always called with
            ``cfg.optim.max_epoch`` as its epoch argument.
        mpnn_th, msa_th, ffn_th, encoder_th, pred_th, global_th: Pruning
            thresholds forwarded to the model when ``cfg.slt.sm`` or
            ``cfg.slt.mm`` is ``True``.
        split: Name written to ``batch.split``.
        ood_classes: Optional tensor of class indices. When given, the
            percentage of predictions falling into these classes is logged.
            Assumes ``pred.argmax(dim=-1)`` yields a 1-D tensor of class
            indices - TODO confirm for all heads.
    """
    ood_matches = 0
    total_preds = 0

    model.eval()
    device = torch.device(cfg.accelerator)
    slt_enabled = cfg.slt.sm is True or cfg.slt.mm is True
    time_start = time.time()

    if ood_classes is not None:
        ood_classes = ood_classes.to(device)

    for batch in loader:
        batch.split = split
        batch.to(device)

        if slt_enabled:
            outputs = model(
                batch,
                cfg.optim.max_epoch,
                mpnn_th,
                msa_th,
                ffn_th,
                encoder_th,
                pred_th,
                global_th,
            )
        else:
            outputs = model(batch)

        # Only the inductive-edge head returns extra statistics.
        if cfg.gnn.head == "inductive_edge":
            pred, true, extra_stats = outputs
        else:
            pred, true = outputs
            extra_stats = {}

        loss, _true, _pred = compute_loss_and_prepare_outputs(pred, true)

        logger.update_stats(
            true=_true,
            pred=_pred,
            loss=loss.detach().cpu().item(),
            lr=0,
            time_used=time.time() - time_start,
            params=cfg.params,
            dataset_name=cfg.dataset.name,
            **extra_stats,
        )
        time_start = time.time()

        if ood_classes is not None:
            pred_classes = pred.argmax(dim=-1).to(device)
            total_preds += pred_classes.size(0)
            # Vectorised membership test instead of a per-element Python
            # loop with tensor `in`.
            ood_matches += int(
                torch.isin(pred_classes, ood_classes).sum().item()
            )

    if ood_classes is not None:
        # Guard against an empty loader, which used to divide by zero.
        ood_pct = ood_matches * 100 / total_preds if total_preds else 0.0
        logging.info(f"OOD Predictions: {ood_pct}")


@register_train("custom")
def custom_train(
    loggers, loaders, model, optimizer, scheduler, ood_classes=None
):
    """
    Customized training pipeline.

    Trains for ``cfg.optim.max_epoch`` epochs (optionally resuming from a
    checkpoint), evaluates the remaining splits on evaluation epochs, keeps
    track of the best epoch and optionally checkpoints it.

    Args:
        loggers: List of loggers; index 0 is the training split, the
            following ones are named "val" and "test"
        loaders: List of loaders, in the same order as ``loggers``
        model: GNN model
        optimizer: PyTorch optimizer
        scheduler: PyTorch learning rate scheduler
        ood_classes: Optional tensor forwarded to ``eval_epoch``

    """
    start_epoch = 0
    if cfg.train.auto_resume:
        start_epoch = load_ckpt(
            model, optimizer, scheduler, cfg.train.epoch_resume
        )
    if start_epoch == cfg.optim.max_epoch:
        logging.info("Checkpoint found, Task already done")
    else:
        logging.info("Start from epoch %s", start_epoch)

    num_splits = len(loggers)
    split_names = ["val", "test"]
    full_epoch_times = []
    perf = [[] for _ in range(num_splits)]

    for cur_epoch in range(start_epoch, cfg.optim.max_epoch):
        torch.backends.cudnn.benchmark = True

        start_time = time.perf_counter()
        model.train()

        train_epoch(
            loggers[0],
            loaders[0],
            model,
            optimizer,
            scheduler,
            cfg.optim.batch_accumulation,
            cur_epoch,
        )
        perf[0].append(loggers[0].write_epoch(cur_epoch))

        if is_eval_epoch(cur_epoch):
            model.eval()
            # Evaluation passes cfg.optim.max_epoch as the epoch, as the
            # original code did. The shared helpers return None for every
            # disabled block, so no threshold variable is left unbound.
            final_epoch = cfg.optim.max_epoch
            if cfg.slt.pruning == "blockwise":
                mpnn_th, msa_th, ffn_th, encoder_th, pred_th = (
                    get_thresholds_blockwise(model, final_epoch)
                )
                global_th = None
            elif cfg.slt.pruning == "global":
                global_th, mpnn_th, msa_th, ffn_th, encoder_th, pred_th = (
                    get_thresholds_global(model, final_epoch)
                )
            else:
                mpnn_th = msa_th = ffn_th = None
                encoder_th = pred_th = global_th = None

            for i in range(1, num_splits):
                eval_epoch(
                    loggers[i],
                    loaders[i],
                    model,
                    cfg.optim.max_epoch,
                    mpnn_th,
                    msa_th,
                    ffn_th,
                    encoder_th,
                    pred_th,
                    global_th,
                    split=split_names[i - 1],
                    ood_classes=ood_classes,
                )
                perf[i].append(loggers[i].write_epoch(cur_epoch))
        else:
            # Carry the last evaluation result forward so that every split
            # has one entry per epoch.
            for i in range(1, num_splits):
                perf[i].append(perf[i][-1])

        val_perf = perf[1]
        if cfg.optim.scheduler == "reduce_on_plateau":
            scheduler.step(val_perf[-1]["loss"])
        else:
            scheduler.step()
        full_epoch_times.append(time.perf_counter() - start_time)
        if is_eval_epoch(cur_epoch):
            best_epoch = np.array([vp["loss"] for vp in val_perf]).argmin()
            best_train = best_val = best_test = ""
            if cfg.metric_best != "auto":
                # Select again based on val perf of `cfg.metric_best`.
                m = cfg.metric_best
                best_epoch = getattr(
                    np.array([vp[m] for vp in val_perf]), cfg.metric_agg
                )()
                if m in perf[0][best_epoch]:
                    best_train = f"train_{m}: {perf[0][best_epoch][m]:.4f}"
                else:
                    best_train = f"train_{m}: {0:.4f}"
                best_val = f"val_{m}: {perf[1][best_epoch][m]:.4f}"
                best_test = f"test_{m}: {perf[2][best_epoch][m]:.4f}"
            if (
                cfg.train.enable_ckpt
                and cfg.train.ckpt_best
                and best_epoch == cur_epoch
            ):
                save_ckpt(model, optimizer, scheduler, cur_epoch)
                print("saved best model")
                if cfg.train.ckpt_clean:  # Delete old ckpt each time.
                    clean_ckpt()
            logging.info(
                f"> Epoch {cur_epoch}: took {full_epoch_times[-1]:.1f}s "
                f"(avg {np.mean(full_epoch_times):.1f}s) | "
                f"Best so far: epoch {best_epoch}\t"
                f"train_loss: {perf[0][best_epoch]['loss']:.4f} {best_train}\t"
                f"val_loss: {perf[1][best_epoch]['loss']:.4f} {best_val}\t"
                f"test_loss: {perf[2][best_epoch]['loss']:.4f} {best_test}"
            )
            if hasattr(model, "trf_layers"):
                # Log SAN's gamma parameter values if they are trainable.
                for li, gtl in enumerate(model.trf_layers):
                    if (
                        torch.is_tensor(gtl.attention.gamma)
                        and gtl.attention.gamma.requires_grad
                    ):
                        logging.info(
                            f"    {gtl.__class__.__name__} {li}: "
                            f"gamma={gtl.attention.gamma.item()}"
                        )

    logging.info(f"Avg time per epoch: {np.mean(full_epoch_times):.2f}s")
    logging.info(
        f"Total train loop time: {np.sum(full_epoch_times) / 3600:.2f}h"
    )
    for logger in loggers:
        logger.close()
    if cfg.train.ckpt_clean:
        clean_ckpt()

    logging.info("Task done, results saved in %s", cfg.run_dir)


@register_train("inference-only")
def inference_only(loggers, loaders, model, optimizer=None, scheduler=None):
    """
    Customized pipeline to run inference only.

    Only the split at index 2 (named "test") is evaluated; the loggers and
    loaders at indices 0 and 1 are not used apart from closing the loggers.

    Args:
        loggers: List of loggers
        loaders: List of loaders
        model: GNN model
        optimizer: Unused, exists just for API compatibility
        scheduler: Unused, exists just for API compatibility
    """

    with torch.no_grad():
        num_splits = len(loggers)
        split_names = ["train", "val", "test"]
        perf = [[] for _ in range(num_splits)]
        cur_epoch = 0
        start_time = time.perf_counter()

        # Thresholds are computed with cfg.optim.max_epoch as the epoch, as
        # in the original code. The shared helpers return None for every
        # disabled block, so no threshold variable is left unbound.
        final_epoch = cfg.optim.max_epoch
        if cfg.slt.pruning == "blockwise":
            mpnn_th, msa_th, ffn_th, encoder_th, pred_th = (
                get_thresholds_blockwise(model, final_epoch)
            )
            global_th = None
        elif cfg.slt.pruning == "global":
            global_th, mpnn_th, msa_th, ffn_th, encoder_th, pred_th = (
                get_thresholds_global(model, final_epoch)
            )
        else:
            mpnn_th = msa_th = ffn_th = None
            encoder_th = pred_th = global_th = None

        # Train and val splits are deliberately skipped here.
        eval_epoch(
            loggers[2],
            loaders[2],
            model,
            cur_epoch,
            mpnn_th,
            msa_th,
            ffn_th,
            encoder_th,
            pred_th,
            global_th,
            split=split_names[2],
        )
        perf[2].append(loggers[2].write_epoch(cur_epoch))

        best_epoch = 0
        best_test = ""
        if cfg.metric_best != "auto":
            m = cfg.metric_best
            best_test = f"test_{m}: {perf[2][best_epoch][m]:.4f}"

        logging.info(
            f"> Inference | "
            f"test_loss: {perf[2][best_epoch]['loss']:.4f} {best_test}"
        )
        logging.info(f"Done! took: {time.perf_counter() - start_time:.2f}s")
        for logger in loggers:
            logger.close()


@register_train("PCQM4Mv2-inference")
def ogblsc_inference(loggers, loaders, model, optimizer=None, scheduler=None):
    """
    Customized pipeline to run inference on OGB-LSC PCQM4Mv2.

    The first split ("valid") is scored with the official evaluator and its
    MAE is logged; predictions for the other two splits are written as
    submission files into ``cfg.run_dir``.

    Args:
        loggers: Unused, exists just for API compatibility
        loaders: List of loaders
        model: GNN model
        optimizer: Unused, exists just for API compatibility
        scheduler: Unused, exists just for API compatibility
    """
    from ogb.lsc import PCQM4Mv2Evaluator

    evaluator = PCQM4Mv2Evaluator()

    num_splits = 3
    split_names = ["valid", "test-dev", "test-challenge"]
    assert len(loaders) == num_splits, "Expecting 3 particular splits."

    # Check PCQM4Mv2 prediction targets: the validation split must have
    # targets, the two test splits must have NaN targets.
    logging.info(f"0 ({split_names[0]}): {len(loaders[0].dataset)}")
    assert all(not torch.isnan(d.y)[0] for d in loaders[0].dataset)
    logging.info(f"1 ({split_names[1]}): {len(loaders[1].dataset)}")
    assert all(torch.isnan(d.y)[0] for d in loaders[1].dataset)
    logging.info(f"2 ({split_names[2]}): {len(loaders[2].dataset)}")
    assert all(torch.isnan(d.y)[0] for d in loaders[2].dataset)

    model.eval()
    for i in range(num_splits):
        all_true = []
        all_pred = []
        # Outputs are detached anyway, so skip building the autograd graph.
        with torch.no_grad():
            for batch in loaders[i]:
                batch.to(torch.device(cfg.accelerator))
                pred, true = model(batch)
                all_true.append(true.detach().to("cpu", non_blocking=True))
                all_pred.append(pred.detach().to("cpu", non_blocking=True))
        all_true, all_pred = torch.cat(all_true), torch.cat(all_pred)

        if i == 0:
            input_dict = {
                "y_pred": all_pred.squeeze(),
                "y_true": all_true.squeeze(),
            }
            result_dict = evaluator.eval(input_dict)
            logging.info(
                f"{split_names[i]}: MAE = {result_dict['mae']}"
            )  # Get MAE.
        else:
            input_dict = {"y_pred": all_pred.squeeze()}
            evaluator.save_test_submission(
                input_dict=input_dict,
                dir_path=cfg.run_dir,
                mode=split_names[i],
            )


@register_train("log-attn-weights")
def log_attn_weights(loggers, loaders, model, optimizer=None, scheduler=None):
    """
    Customized pipeline to inference on the test set and log the attention
    weights in Transformer modules.

    Batches are drawn from a shuffled copy of the last loader until at least
    128 graphs were collected; the result is saved to
    ``graph_attn_stats.pt`` inside ``cfg.run_dir``.

    Args:
        loggers: Unused, exists just for API compatibility
        loaders: List of loaders
        model (torch.nn.Module): GNN model
        optimizer: Unused, exists just for API compatibility
        scheduler: Unused, exists just for API compatibility
    """
    import os.path as osp

    from torch_geometric.loader.dataloader import DataLoader

    from graphgps.utils import unbatch, unbatch_edge_index

    start_time = time.perf_counter()

    # The last loader is a test set.
    ld = loaders[-1]
    # To get a random sample, create a new loader that shuffles the test set.
    loader = DataLoader(
        ld.dataset,
        batch_size=ld.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
    )

    # The mode does not change between batches, so set it once.
    model.eval()
    # NOTE(review): the forward pass runs with autograd enabled, as in the
    # original; consider torch.no_grad() if gradients are never needed here.

    output = []
    for b_index, batch in enumerate(loader):
        if len(output) >= 128:
            break
        bsize = batch.batch.max().item() + 1  # Batch size.
        print(f">> Batch {b_index}:")

        X_orig = unbatch(batch.x.cpu(), batch.batch.cpu())
        batch.to(torch.device(cfg.accelerator))
        model(batch)

        # Unbatch to individual graphs.
        X = unbatch(batch.x.cpu(), batch.batch.cpu())
        edge_indices = unbatch_edge_index(
            batch.edge_index.cpu(), batch.batch.cpu()
        )
        graphs = [
            {
                "num_nodes": len(X[i]),
                "x_orig": X_orig[i],
                "x_final": X[i],
                "edge_index": edge_indices[i],
                # List with attn weights in layers from 0 to L-1.
                "attn_weights": [],
            }
            for i in range(bsize)
        ]

        # Iterate through GPS layers and pull out stored attn weights.
        for l_i, (name, module) in enumerate(
            model.model.layers.named_children()
        ):
            if hasattr(module, "attn_weights"):
                print(l_i, name, module.attn_weights.shape)
                for g_i in range(bsize):
                    # The weights are stored unclipped, i.e. not cut down
                    # to the number of nodes of the individual graph.
                    aw = module.attn_weights[g_i]
                    graphs[g_i]["attn_weights"].append(aw.cpu())
        output += graphs

    # An empty loader yields no graphs; report zero layers instead of
    # failing on output[0].
    num_layers = len(output[0]["attn_weights"]) if output else 0
    logging.info(
        f"[*] Collected a total of {len(output)} graphs and their "
        f"attention weights for {num_layers} layers."
    )

    # Save the graphs and their attention stats.
    save_file = osp.join(cfg.run_dir, "graph_attn_stats.pt")
    logging.info(f"Saving to file: {save_file}")
    torch.save(output, save_file)

    logging.info(f"Done! took: {time.perf_counter() - start_time:.2f}s")
