r"""Training pipeline: training/evaluation structure, batch training.
"""
import datetime
import os
import shutil
from typing import Dict
from typing import Union
import random
import copy
from collections import defaultdict

import numpy as np
import torch
import torch.nn
import torch.nn.functional as F
from torch_scatter import scatter_mean, scatter_std, scatter_min, scatter_max, scatter_add
from munch import Munch
import torch_geometric
from torch_geometric.loader import DataLoader
from torch_geometric.data import Batch, Data, InMemoryDataset
from torch_geometric.utils import to_networkx, from_networkx, to_undirected, sort_edge_index, shuffle_node, is_undirected, contains_self_loops, contains_isolated_nodes, coalesce, subgraph, k_hop_subgraph
from tqdm import tqdm
import networkx as nx
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score as sk_roc_auc, f1_score, accuracy_score, precision_recall_fscore_support
from imblearn.under_sampling import RandomUnderSampler

from GOOD.ood_algorithms.algorithms.BaseOOD import BaseOODAlg
from GOOD.utils.args import CommonArgs
from GOOD.utils.evaluation import eval_data_preprocess, eval_score
from GOOD.utils.logger import pbar_setting
from GOOD.utils.register import register
from GOOD.utils.train import nan2zero_get_mask
from GOOD.utils.initial import reset_random_seed
import GOOD.kernel.pipelines.xai_metric_utils as xai_utils
from GOOD.networks.models.DIRGNN import split_graph_node

import wandb

# Silence every tqdm bar created with **pbar_setting. The dict is the one
# imported from GOOD.utils.logger, so this mutation is visible to other modules
# that use it as well.
pbar_setting.update(disable=True)

class CustomDataset(InMemoryDataset):
    r"""In-memory PyG dataset assembled from an explicit list of graphs.

    Each element of ``samples`` is either a ``networkx.DiGraph`` (converted with
    ``from_networkx``) or a PyG data object (deep-copied). Every graph is stripped
    down to the attribute names present on ``samples[1]`` so that all graphs can
    be collated together. Each graph is then tagged with ``belonging`` (taken from
    ``belonging[i]``) and ``idx`` (its position in ``samples``).

    Args:
        root: forwarded to ``InMemoryDataset`` (callers in this file pass ``None``).
        samples (list): graphs to store; must contain at least two elements,
            because ``samples[1]`` defines the retained attribute names.
        belonging (list): per-sample value stored as ``data.belonging``.
        transform, pre_transform, pre_filter: standard ``InMemoryDataset`` hooks;
            ``pre_filter``/``pre_transform`` are applied here before collation.
    """

    def __init__(self, root, samples, belonging, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)

        # The attribute schema every stored graph is reduced to.
        allowed_keys = samples[1].keys()

        graphs = []
        for position, sample in enumerate(samples):
            graph = (
                from_networkx(sample)
                if type(sample) is nx.classes.digraph.DiGraph
                else copy.deepcopy(sample)
            )

            # Drop every attribute outside the shared schema so collate() succeeds.
            for key in [k for k in graph.keys() if k not in allowed_keys]:
                del graph[key]

            graph.belonging = belonging[position]
            graph.idx = position
            graphs.append(graph)

        if self.pre_filter is not None:
            graphs = list(filter(self.pre_filter, graphs))
        if self.pre_transform is not None:
            graphs = list(map(self.pre_transform, graphs))

        self.data, self.slices = self.collate(graphs)


@register.pipeline_register
class Pipeline:
    r"""
    Kernel pipeline.

    Args:
        task (str): Current running task. 'train' or 'test'
        model (torch.nn.Module): The GNN model.
        loader (Union[DataLoader, Dict[str, DataLoader]]): The data loader.
        ood_algorithm (BaseOODAlg): The OOD algorithm.
        config (Union[CommonArgs, Munch]): Please refer to :ref:`configs:GOOD Configs and command line Arguments (CA)`.

    """

    def __init__(self, task: str, model: torch.nn.Module, loader: Union[DataLoader, Dict[str, DataLoader]],
                 ood_algorithm: BaseOODAlg,
                 config: Union[CommonArgs, Munch]):
        r"""Keep references to the components the pipeline operates on.

        No setup work happens here; model configuration and optimizer creation
        are triggered later (see ``train``).
        """
        super(Pipeline, self).__init__()
        # Run configuration and which task ('train' or 'test') is requested.
        self.config = config
        self.task = task
        # The network and the OOD algorithm wrapping its losses/optimizer.
        self.model = model
        self.ood_algorithm = ood_algorithm
        # Either a single loader or a mapping from split name to loader.
        self.loader = loader

    def get_pretrain_targets(self, data):
        r"""
        Build the per-node targets used to pretrain the explanation detector.

        A node gets target 1.0 when it belongs to the hand-crafted explanation
        prescribed for the label of its graph, and 0.0 otherwise. The selection
        rule depends on ``config.train.pretrain`` (``"degenerate"``, ``"suff"`` or
        ``"sub"``) and on ``config.dataset.dataset_name``.

        Args:
            data (Batch): batched graphs exposing ``y``, ``batch`` and ``x``.
                Depending on the dataset, it must also expose ``sp_order``,
                ``node_type`` or ``node_is_spurious``.

        Returns:
            torch.Tensor: float tensor of shape ``(data.x.shape[0],)`` on
            ``data.x.device``.

        Raises:
            ValueError: if the pretrain mode, or the dataset for that mode, is
            not supported.
        """
        mode = self.config.train.pretrain
        dataset_name = self.config.dataset.dataset_name

        # For each node, the label of the graph it belongs to.
        if len(data.y.shape) > 1:
            graph_label_per_node = data.y.view(-1)[data.batch]
        else:
            graph_label_per_node = data.y[data.batch]

        targets = torch.full(
            size=(data.x.shape[0],),
            fill_value=0.,
            device=data.x.device
        )

        if mode == "degenerate":
            if dataset_name == "BAColorGVIsolated":
                # G iff R<B; V iff R>=B
                violet_nodes_for_negative = torch.logical_and(data.x[:, 3] == 1, graph_label_per_node == 0).float()
                green_nodes_for_positive = torch.logical_and(data.x[:, 2] == 1, graph_label_per_node == 1).float()
                targets += green_nodes_for_positive + violet_nodes_for_negative
            elif dataset_name in ("MNIST", "CPatchMNIST", "CPatchMNIST2"):
                # For a graph of class C, mark the superpixels whose sp_order is C,
                # or (max sp_order in that graph) - C.
                num_max_sp_per_batch = scatter_max(data.sp_order, index=data.batch)[0][data.batch]
                targets = torch.logical_or(
                    data.sp_order == graph_label_per_node,
                    data.sp_order == (num_max_sp_per_batch - graph_label_per_node)
                ).float()
            elif dataset_name == "MUTAG":
                carbon_nodes_for_nonmutag = torch.logical_and(data.node_type == 0, graph_label_per_node == 1).float()
                hydrogen_nodes_for_mutag = torch.logical_and(data.node_type == 3, graph_label_per_node == 0).float()
                targets += carbon_nodes_for_nonmutag + hydrogen_nodes_for_mutag
            elif dataset_name == "GraphSST2Planted":
                period_for_class0 = torch.logical_and(data.node_type == 1, graph_label_per_node == 0).float()
                comma_for_class1 = torch.logical_and(data.node_type == 2, graph_label_per_node == 1).float()
                targets += period_for_class0 + comma_for_class1
            else:
                raise ValueError(f"{self.config.dataset.dataset_name} not supported for pretrain")
        elif mode == "suff":
            if dataset_name == "BAColorGVIsolated":
                # Highlight all red nodes for class 0, and all blue nodes for class 1. Easy, but suboptimal
                blue_nodes_for_positive = torch.logical_and(data.x[:, 1] == 1, graph_label_per_node == 1).float()
                red_nodes_for_negative = torch.logical_and(data.x[:, 0] == 1, graph_label_per_node == 0).float()
                targets += blue_nodes_for_positive + red_nodes_for_negative
            elif dataset_name in ("MNIST", "CPatchMNIST", "CPatchMNIST2"):
                # Mark "white" nodes (all of the first three features > 0.3)
                # together with their entire 1-hop neighbourhood.
                subset, _, _, _ = k_hop_subgraph(
                    node_idx=torch.nonzero(data.x[:, :3].min(1).values > 0.3).view(-1),
                    num_hops=1,
                    edge_index=data.edge_index,
                    num_nodes=data.x.shape[0]
                )
                targets[subset] = 1.0
            else:
                raise ValueError(f"{self.config.dataset.dataset_name} not supported for pretrain")
        elif mode == "sub":
            if dataset_name == "BAColorRBIsolated":
                # 1R when R>=B; 1B when R<B
                blue_nodes_for_positive = torch.logical_and(data.x[:, 1] == 1, graph_label_per_node == 1)
                red_nodes_for_negative = torch.logical_and(data.x[:, 0] == 1, graph_label_per_node == 0)
                # Select only the isolated R/B node
                blue_isol_node_for_positive = torch.logical_and(blue_nodes_for_positive, data.node_is_spurious).float()
                red_isol_node_for_negative = torch.logical_and(red_nodes_for_negative, data.node_is_spurious).float()
                targets += blue_isol_node_for_positive + red_isol_node_for_negative
            else:
                raise ValueError(f"{self.config.dataset.dataset_name} not supported for pretrain")
        else:
            # Explicit raise instead of `assert False`: asserts vanish under -O,
            # which would silently return all-zero targets.
            raise ValueError(f"Pretrain mode {mode!r} not supported")
        return targets


    def _pretrain_performance_bars(self):
        r"""
        Select the stopping thresholds used by :meth:`pretrain_model`.

        Returns:
            tuple: ``(performance_bar, performance_bar_clf_loss, performance_bar_det_loss)``.
            These are, respectively: the minimum epoch-mean value required for the
            positive F1, the negative F1 and the task score; the maximum epoch-mean
            classifier loss; and the maximum epoch-mean detector loss.

        Raises:
            ValueError: if ``config.train.pretrain`` is not ``"suff"``,
            ``"degenerate"`` or ``"sub"``.
        """
        mode = self.config.train.pretrain
        dataset_name = self.config.dataset.dataset_name
        is_dir = self.config.model.model_name == "DIR"
        # Default detector-loss bar; 100 is far above typical BCE values, so it
        # rarely constrains stopping.
        det_bar = 100

        if mode == "suff":
            if dataset_name == "MNIST":
                return 0.95, 0.08, det_bar
            if dataset_name in ("CPatchMNIST", "CPatchMNIST2"):
                return 0.99, 0.01, 0.005  # detector bar was 0.1 before
            return 0.95, 0.01, det_bar
        if mode == "degenerate":
            if dataset_name in ("MNIST", "CPatchMNIST", "CPatchMNIST2"):
                return (0.98, 0.01, det_bar) if is_dir else (0.95, 0.08, det_bar)  # non-DIR clf bar was 0.03
            if dataset_name == "MUTAG":
                return 0.95, 0.08, det_bar  # clf bar was 0.03
            if dataset_name == "GraphSST2Planted":
                # DIR: 0.98 is what the bcp_* checkpoints were saved with
                return (0.99, 0.015, det_bar) if is_dir else (0.96, 0.015, det_bar)
            # DIR clf bar 0.011: otherwise training gets stuck at 0.011
            return (0.99, 0.011, det_bar) if is_dir else (0.99, 0.01, det_bar)
        if mode == "sub":
            # DIR clf bar 0.018: otherwise training gets stuck at 0.017
            return (0.98, 0.018, 0.01) if is_dir else (0.98, 0.01, 0.01)
        raise ValueError(f"Pretrain mode {mode!r} not supported")

    def pretrain_model(self, loader: DataLoader, val_loader: DataLoader) -> None:
        r"""
        Pretrain the model so that its explainer emits prescribed explanations.

        Each epoch jointly optimizes two losses:

        * the classification loss;
        * a weighted binary cross-entropy between ``sigmoid(ood_algorithm.att)``
          and :meth:`get_pretrain_targets`, with weight 100 on positive nodes
          and 1 on the rest.

        Training stops once the epoch-mean positive/negative detector F1 and task
        score reach ``performance_bar``, and both epoch-mean losses drop below
        their bars (see :meth:`_pretrain_performance_bars`). It also stops after
        epoch 1500. The model is then evaluated on ``eval_train``, ``id_val`` and
        ``id_test`` and saved through :meth:`save_epoch`.

        Args:
            loader (DataLoader): training loader iterated every epoch.
            val_loader (DataLoader): currently unused; kept for API compatibility.

        Raises:
            ValueError: if ``config.train.pretrain`` is not supported.
        """
        performance_bar, performance_bar_clf_loss, performance_bar_det_loss = self._pretrain_performance_bars()
        max_pretrain_epoch = 1500

        f1_pos_epoch, f1_neg_epoch, acc_epoch = 0, 0, 0
        clf_loss_epoch, det_loss_epoch = 10, 10
        epoch = -1
        while (
            min(f1_pos_epoch, f1_neg_epoch, acc_epoch) < performance_bar
            or clf_loss_epoch > performance_bar_clf_loss
            or det_loss_epoch > performance_bar_det_loss
        ):
            if epoch >= max_pretrain_epoch:
                print("\n\nReached max number of epochs. Stopping pretraining.\n\n")
                break
            epoch += 1
            self.config.train.epoch = epoch
            print(f'\nEpoch {epoch}:')

            per_batch_metrics = defaultdict(list)
            pbar = tqdm(enumerate(loader), total=len(loader), **pbar_setting)
            for _, data in pbar:
                self.ood_algorithm.optimizer.zero_grad()
                data = data.to(self.config.device)
                mask, targets = nan2zero_get_mask(data, 'train', self.config)
                node_norm = data.get('node_norm') if self.config.model.model_level == 'node' else None
                # Use everything input_preprocess returns (as train_batch does), so
                # targets/node_norm stay consistent with the returned data and mask.
                data, targets, mask, node_norm = self.ood_algorithm.input_preprocess(
                    data,
                    targets,
                    mask,
                    node_norm,
                    self.model.training,
                    self.config
                )

                model_output = self.model(
                    data=data,
                    edge_weight=None,
                    ood_algorithm=self.ood_algorithm,
                    max_num_epoch=self.config.train.max_epoch,
                    curr_epoch=epoch,
                    pretrain=False
                )

                # Classification loss
                raw_pred = self.ood_algorithm.output_postprocess(model_output)
                clf_loss = self.ood_algorithm.loss_classifier(raw_pred, targets, mask, node_norm, self.config, batch=data.batch).sum() / mask.sum()

                # Detector loss against the prescribed node explanations
                node_att = self.ood_algorithm.att.sigmoid()
                det_targets = self.get_pretrain_targets(data)

                uninformative_value = 0  # self.config.ood.extra_param[2] if "GSAT" in self.config.model.model_name else 0.
                det_targets[det_targets == 0] = uninformative_value

                # Weighted binary cross-entropy: up-weight the (rare) positive nodes
                loss_weight = det_targets.clone()
                loss_weight[det_targets == 0] = 1
                loss_weight[det_targets == 1] = 100  # TODO: 10 for BAColor; 100 for others
                detector_loss = F.binary_cross_entropy(node_att.squeeze(1), det_targets, weight=loss_weight)

                self.ood_algorithm.backward(detector_loss + clf_loss)

                pred, target = eval_data_preprocess(data.y, raw_pred, mask, self.config)
                task_score = eval_score([pred], [target], self.config, pos_class=self.loader["train"].dataset.minority_class)
                is_expl_node = (det_targets > uninformative_value).cpu().numpy()
                f1_pos = f1_score(
                    is_expl_node,
                    (node_att.squeeze(1) > 0.9).cpu().numpy(),
                    average='binary',
                    pos_label=1
                )
                f1_neg = f1_score(
                    is_expl_node,
                    (node_att.squeeze(1) > 0.1).cpu().numpy(),
                    average='binary',
                    pos_label=0
                )
                per_batch_metrics["loss"].append(detector_loss.item())
                per_batch_metrics["f1_pos"].append(f1_pos)
                per_batch_metrics["f1_neg"].append(f1_neg)
                per_batch_metrics["task_score"].append(task_score)
                per_batch_metrics["clf_loss"].append(clf_loss.item())
            f1_pos_epoch = np.mean(per_batch_metrics['f1_pos'])
            f1_neg_epoch = np.mean(per_batch_metrics['f1_neg'])
            acc_epoch = np.mean(per_batch_metrics['task_score'])
            clf_loss_epoch = np.mean(per_batch_metrics['clf_loss'])
            det_loss_epoch = np.mean(per_batch_metrics['loss'])

            # --- scheduler step ---
            self.ood_algorithm.scheduler.step()

        epoch_train_stat = self.evaluate(
            'eval_train',
            compute_plaus=False,
            epoch=self.config.train.max_epoch
        )
        id_val_stat = self.evaluate('id_val', epoch=self.config.train.max_epoch)
        id_test_stat = self.evaluate('id_test', epoch=self.config.train.max_epoch)
        val_stat = id_val_stat
        test_stat = id_test_stat
        loss_per_batch_dict = {}

        self.save_epoch(
            self.config.train.max_epoch,
            epoch_train_stat, id_val_stat, id_test_stat, val_stat, test_stat,
            self.config,
            loss_per_batch_dict,
        )
        return None

    def train_batch(self, data: Batch, pbar, epoch:int) -> dict:
        r"""
        Run one optimisation step on a single batch. (Project use only)

        Args:
            data (Batch): Current batch of data.
            pbar: progress-bar handle; accepted for API compatibility, not used.
            epoch (int): current epoch, forwarded to the model and to
                ``loss_postprocess``.

        Returns:
            dict: ``'loss'`` (detached total loss tensor), ``'score'`` (batch
            score from ``eval_score``), and the individual loss terms read from
            the OOD algorithm after the step.
        """
        cfg = self.config
        alg = self.ood_algorithm
        is_node_level = cfg.model.model_level == 'node'

        graph_batch = data.to(cfg.device)
        alg.optimizer.zero_grad()

        mask, targets = nan2zero_get_mask(graph_batch, 'train', cfg)
        node_norm = graph_batch.get('node_norm') if is_node_level else None
        graph_batch, targets, mask, node_norm = alg.input_preprocess(
            graph_batch, targets, mask, node_norm, self.model.training, cfg
        )
        edge_weight = graph_batch.get('edge_norm') if is_node_level else None

        raw_pred = alg.output_postprocess(
            self.model(
                data=graph_batch,
                edge_weight=edge_weight,
                ood_algorithm=alg,
                max_num_epoch=cfg.train.max_epoch,
                curr_epoch=epoch,
            )
        )

        per_sample_loss = alg.loss_calculate(raw_pred, targets, mask, node_norm, cfg, batch=graph_batch.batch)
        total = alg.loss_postprocess(per_sample_loss, graph_batch, mask, cfg, epoch)
        alg.backward(total)

        pred, target = eval_data_preprocess(graph_batch.y, raw_pred, mask, cfg)
        batch_score = eval_score([pred], [target], cfg, pos_class=self.loader["train"].dataset.minority_class)

        stats = {
            'loss': total.detach(),
            'score': batch_score,
            'clf_loss': alg.clf_loss,
        }
        # The remaining loss terms are exposed by the algorithm and cast to float.
        for term in ('l_norm_loss', 'entr_loss', 'spec_loss', 'mean_loss', 'total_loss'):
            stats[term] = float(getattr(alg, term))
        return stats


    def train(self):
        r"""
        Training pipeline.

        Configures the model and the OOD algorithm. If ``config.train.pretrain``
        is truthy, runs :meth:`pretrain_model` on the ``'train'`` loader and
        returns. Otherwise, evaluates once before training and then runs epochs
        ``config.train.ctn_epoch`` .. ``config.train.max_epoch - 1``. Each epoch:

        * trains on every full batch (incomplete batches are skipped);
        * evaluates on ``eval_train`` / ``id_val`` / ``id_test``;
        * optionally logs to wandb;
        * saves through :meth:`save_epoch`;
        * steps the scheduler.
        """
        loss_names = ("mean_loss", "spec_loss", "total_loss", "entr_loss", "l_norm_loss", "clf_loss")

        if self.config.wandb:
            wandb.login()

        # config model
        print('Config model')
        self.config_model('train')

        # Load training utils
        print('Load training utils')
        self.ood_algorithm.set_up(self.model, self.config)

        if self.config.train.pretrain:
            # Attack models by pretraining them only to output degenerate explanations
            print("#IM#Pretraining model for degenerate explanations")
            self.pretrain_model(self.loader['train'], self.loader['id_val'])
            print("#IM#End of pretraining")
            return

        print("Before training:")
        epoch_train_stat = self.evaluate('eval_train', epoch=0)
        id_val_stat = self.evaluate('id_val', epoch=0)
        id_test_stat = self.evaluate('id_test', epoch=0)

        if self.config.wandb:
            wandb.log({
                    "epoch": -1,
                    "all_train_loss": epoch_train_stat["loss"],
                    "all_id_val_loss": id_val_stat["loss"],
                    "train_score": epoch_train_stat["score"],
                    "id_val_score": id_val_stat["score"],
                    "id_test_score": id_test_stat["score"],
                    "val_score": np.nan,
                    "test_score": np.nan,
                },
                step=0
            )

        # train the model; wandb step 0 is taken by the evaluation above
        counter = 1
        for epoch in range(self.config.train.ctn_epoch, self.config.train.max_epoch):
            self.config.train.epoch = epoch
            print(f'\nEpoch {epoch}:')

            self.ood_algorithm.stage_control(self.config)

            pbar = tqdm(enumerate(self.loader['train']), total=len(self.loader['train']), **pbar_setting)
            edge_scores = []
            train_batch_score = []
            loss_per_batch_dict = defaultdict(list)
            for _, data in pbar:
                # Skip a batch whose largest graph index is below train_bs - 1,
                # i.e. a batch with fewer than train_bs graphs.
                if data.batch is not None and (data.batch[-1] < self.config.train.train_bs - 1):
                    continue

                # train a batch
                train_stat = self.train_batch(data, pbar, epoch)

                # log stats
                train_batch_score.append(train_stat["score"])
                for name in loss_names:
                    loss_per_batch_dict[name].append(train_stat.get(name, np.nan))

                if self.config.model.model_name != "GIN":
                    edge_scores.append(self.ood_algorithm.edge_att.detach().cpu())

            for name in loss_names:
                loss_per_batch_dict[name] = np.mean(loss_per_batch_dict[name])

            # Epoch val
            print('Evaluating...')
            print(f"Clf loss: {loss_per_batch_dict['clf_loss']:.4f}")
            print(f"Spec loss: {loss_per_batch_dict['spec_loss']:.4f}")
            print(f"Mean loss: {loss_per_batch_dict['mean_loss']:.4f}")
            print(f"Total loss: {loss_per_batch_dict['total_loss']:.4f}")

            epoch_train_stat = self.evaluate(
                'eval_train',
                compute_plaus=False,
                epoch=epoch
            )
            id_val_stat = self.evaluate('id_val', epoch=epoch)
            id_test_stat = self.evaluate('id_test', epoch=epoch)

            # Model selection and test reporting currently use the ID splits for
            # every shift type.
            val_stat = id_val_stat
            test_stat = id_test_stat

            # torch.cat raises on an empty list: GIN never records edge scores,
            # and every batch may have been skipped.
            all_edge_scores = torch.cat(edge_scores, dim=0) if edge_scores else None
            if all_edge_scores is not None:
                print("edge_weight: ", all_edge_scores.min(), all_edge_scores.max(), all_edge_scores.mean())

            if self.config.wandb:
                log_dict = {
                    "epoch": epoch,
                    "clf_loss": loss_per_batch_dict["clf_loss"],
                    "mean_loss": loss_per_batch_dict["mean_loss"],
                    "spec_loss": loss_per_batch_dict["spec_loss"],
                    "total_loss": loss_per_batch_dict["total_loss"],
                    "l_norm_loss": loss_per_batch_dict["l_norm_loss"],
                    "entr_loss": loss_per_batch_dict["entr_loss"],
                    "all_train_loss": epoch_train_stat["loss"],
                    "all_id_val_loss": id_val_stat["loss"],
                    "train_batch_score": np.mean(train_batch_score),
                    "train_score": epoch_train_stat["score"],
                    "id_val_score": id_val_stat["score"],
                    "id_test_score": id_test_stat["score"],
                    "val_score": val_stat["score"],
                    "test_score": test_stat["score"],
                }
                if all_edge_scores is not None:
                    log_dict["edge_weight"] = wandb.Histogram(sequence=all_edge_scores, num_bins=100)
                log_dict["wiou"] = epoch_train_stat["wiou"]
                wandb.log(log_dict, step=counter)
                counter += 1

            # checkpoints save
            self.save_epoch(epoch, epoch_train_stat, id_val_stat, id_test_stat, val_stat, test_stat, self.config, loss_per_batch_dict)

            # --- scheduler step ---
            self.ood_algorithm.scheduler.step()


    @torch.no_grad()
    def evaluate_graphs(self, loader, clfonly, log=False, **kwargs):
        r"""
        Run the model over every batch in ``loader`` and gather its outputs.

        Args:
            loader: iterable of batches exposing ``belonging`` (and
                ``node_expl`` when ``clfonly`` is set).
            clfonly (bool): if True, call ``model.predict_from_subgraph`` with
                ``data.node_expl`` as node attention; otherwise call
                ``model.log_probs`` (``log=True``) or ``model.probs``.
            log (bool): select log-probabilities instead of probabilities when
                ``clfonly`` is False.
            **kwargs: forwarded to the selected model method.

        Returns:
            tuple: ``(preds_eval, belonging)``. ``preds_eval`` is a tensor built
            from the concatenated per-batch outputs; ``belonging`` is an integer
            tensor of the per-graph ``belonging`` values.
        """
        collected_preds = []
        collected_owners = []
        for batch in tqdm(loader, desc='Eval intervened graphs', total=len(loader), **pbar_setting):
            batch: Batch = batch.to(self.config.device)
            if clfonly:
                out = self.model.predict_from_subgraph(
                    data=batch,
                    edge_weight=None,
                    edge_attn=None,
                    node_att=batch.node_expl.unsqueeze(1),
                    ood_algorithm=self.ood_algorithm,
                    **kwargs,
                )
            else:
                forward = self.model.log_probs if log else self.model.probs
                out = forward(data=batch, edge_weight=None, ood_algorithm=self.ood_algorithm, **kwargs)
            collected_preds += out.detach().cpu().numpy().tolist()
            collected_owners += batch.belonging.detach().cpu().numpy().tolist()
        return torch.tensor(collected_preds), torch.tensor(collected_owners, dtype=int)


    @torch.no_grad()
    def generate_binary_explanations(self, is_weight, thrs, splits, convert_to_nx, is_node_expl):
        r"""
        Threshold the model's node explanations into binary node/edge masks.

        For each split, every graph the model classifies correctly is
        deep-copied once per threshold in ``thrs``. Each copy gets:

        * ``y_pred``;
        * ``node_expl`` (the model's node scores);
        * ``node_mask`` (``node_expl >= thr``, additionally restricted to the
          TopK nodes for DIR);
        * ``edge_mask`` (edges of the node-induced subgraph).

        Args:
            is_weight: currently unused.
            thrs (list): thresholds applied to the node scores.
            splits (list): split names passed to ``get_local_dataset``.
            convert_to_nx (bool): also convert the graphs of ``thrs[0]`` to networkx.
            is_node_expl (bool): must be True; only node explanations are supported.

        Returns:
            tuple: ``(samples, graphs_nx, avg_graph_size)``, where
            ``samples[split][thr]`` is a list of CPU graphs, ``graphs_nx[split]`` is
            a list of networkx graphs (empty unless ``convert_to_nx``), and
            ``avg_graph_size[split]`` is the mean edge count of the ``thrs[0]`` graphs.

        Raises:
            ValueError: if ``is_node_expl`` is False.
        """
        if not is_node_expl:
            raise ValueError("only node expl for now")

        reset_random_seed(self.config)
        self.model.eval()

        samples = {split: {thr: [] for thr in thrs} for split in splits}
        avg_graph_size = {}
        graphs_nx = {}
        for split in splits:
            dataset = self.get_local_dataset(split)
            loader = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=2)
            for data in loader:
                data: Batch = data.to(self.config.device)
                _, node_scores, logits = self.model.get_subgraph(
                    data=data,
                    edge_weight=None,
                    ood_algorithm=self.ood_algorithm,
                    do_relabel=False
                )

                for j, g in enumerate(data.to_data_list()):
                    for thr in thrs:
                        new_g = copy.deepcopy(g)
                        new_g.y_pred = logits[j].argmax(0) if logits[j].shape[-1] > 1 else int(logits[j] > 0.5)

                        # Keep only correctly classified graphs
                        if new_g.y_pred != new_g.y:
                            continue

                        new_g.node_expl = node_scores[data.batch == j].squeeze(1)

                        # compute binary node mask based on threshold here for convenience
                        new_g.node_mask = new_g.node_expl >= thr

                        if self.config.model.model_name == "DIR":
                            # remove nodes not in the TopK
                            _, _, (topK_nodes_kept, topK_nodes_removed) = split_graph_node(
                                g, new_g.node_expl, self.config.ood.ood_param, embed=None, use_input_feat=True
                            )
                            assert topK_nodes_kept.shape[0] + topK_nodes_removed.shape[0] == g.x.shape[0]
                            new_g.node_mask[topK_nodes_removed] = False

                        # Binary edge mask = edges of the node-induced subgraph
                        _, _, new_g.edge_mask = subgraph(new_g.node_mask, new_g.edge_index, None, return_edge_mask=True)

                        samples[split][thr].append(new_g.to("cpu"))

            kept = samples[split][thrs[0]]
            avg_graph_size[split] = np.mean([s.edge_index.shape[1] for s in kept])

            if convert_to_nx:
                print("Converting graphs to networkx")
                # Inspect the graphs actually being converted (not a leftover loop
                # variable, which may be undefined when the loader is empty).
                available_keys = kept[0].keys() if kept else []
                edge_attr_tokeep = [s for s in ("edge_attr", "edge_gt") if s in available_keys]
                graphs_nx[split] = [
                    to_networkx(s, node_attrs=["x", "node_is_spurious"], edge_attrs=edge_attr_tokeep or None)
                    for s in kept
                ]
            else:
                graphs_nx[split] = list()
        return samples, graphs_nx, avg_graph_size


    @torch.no_grad()
    def compute_metric(
        self,
        metric: str,
        graphs,
        graphs_nx,
        avg_graph_size,
        log_info=True
    ):
        if log_info:
            print(f"\n\n", "-"*50)
        reset_random_seed(self.config)
        self.model.eval()   

        scores, acc_ints = defaultdict(list), []

        eval_samples, belonging, reference = [], [], []
        preds_ori, labels_ori, expl_acc_ori = [], [], []
        graph_database_labels = torch.tensor([g.y.item() for g in graphs], device=graphs[0].y.device)

        pbar = tqdm(range(len(graphs[:])), desc=f'Creating Intervent. distrib.', total=len(graphs), **pbar_setting)
        for i in pbar:
            if metric == "fidm" or metric == "fidp":
                intervened_graphs = xai_utils.fidelity(
                    graphs[i],
                    type=metric
                )
            elif metric == "rfidm" or metric == "rfidp":
                intervened_graphs = xai_utils.robust_fidelity(
                    graphs[i],
                    type=metric,
                    p=self.config.rfid_alpha_1 if metric == "rfidp" else self.config.rfid_alpha_2,
                    expval_budget=self.config.expval_budget
                )
            elif metric == "nec":
                intervened_graphs = xai_utils.nec_budget(
                    graphs[i],
                    avg_graph_size=avg_graph_size,
                    p=self.config.nec_budget,
                    expval_budget=self.config.expval_budget
                )
            elif metric == "suff":
                intervened_graphs = xai_utils.suff_intervent(
                    graphs[i],
                    graph_database=graphs,
                    graph_database_labels=graph_database_labels,
                    expval_budget=self.config.expval_budget
                )
            elif metric == "counter_fid":
                intervened_graphs = xai_utils.counter_fid(
                    graphs[i],
                    expval_budget=self.config.expval_budget
                )
            elif metric == "suff_cause":
                intervened_graphs = xai_utils.suff_cause(
                    graphs[i],
                    expval_budget=self.config.expval_budget
                )
            else:
                raise ValueError(f"Metric {metric} not supported")

            if intervened_graphs is not None:
                eval_samples.append(graphs[i])
                reference.append(len(eval_samples) - 1)
                belonging.append(-1)
                labels_ori.append(graphs[i].y)
                belonging.extend([i] * len(intervened_graphs))
                eval_samples.extend(intervened_graphs)

        if len(eval_samples) <= 1:
            print(f"\nToo few intervened samples, skipping this")
            exit()
            scores["all_KL"].append(1.0)
            scores["all_L1"].append(1.0)
            scores["rejection"].append(np.nan)
            acc_ints.append(-1.0)
            return scores, None
        
        ##
        # Compute new predictions
        ##
        int_dataset = CustomDataset(root=None, samples=eval_samples, belonging=belonging)
        loader = DataLoader(int_dataset, batch_size=256, shuffle=False)
        preds_eval, belonging = self.evaluate_graphs(loader, log=False, clfonly=metric in ["counter_fid"])

        preds_clean_graphs = preds_eval[reference]
        
        mask = torch.ones(preds_eval.shape[0], dtype=bool)
        mask[reference] = False
        preds_perturbed_graphs = preds_eval[mask]
        belonging = belonging[mask]            
        assert torch.all(belonging >= 0), f"{torch.all(belonging >= 0)}"

        num_perturbation_per_sample = 1 if metric in ("fidm", "fidp") else self.config.expval_budget
        labels_ori = torch.tensor(labels_ori)
        preds_clean_graphs_repeated = preds_clean_graphs.repeat_interleave(num_perturbation_per_sample, dim=0)
        labels_ori_repeated = labels_ori.repeat_interleave(num_perturbation_per_sample, dim=0)

        ##
        # Compute metric values
        ##
        aggr = self.get_aggregated_metric(
            metric,
            preds_clean_graphs_repeated,
            preds_perturbed_graphs,
            belonging
        )
        
        ##
        # Store and print metric values
        ##
        for m in ["TV", "predicted"]:
            # for c in labels_ori.long().unique():
            for c in graph_database_labels.long().unique():
                idx_class = np.arange(labels_ori.shape[0])[(labels_ori == c).numpy()]
                if len(idx_class) <= 10:
                    scores[f"{c.item()}_{m}"].append(np.nan)
                else:    
                    scores[f"{c.item()}_{m}"].append(round(aggr[m][idx_class].mean().item(), 3))
                # print(f"Class {c} rej={aggr['rejection'][idx_class].float().mean()}")
            scores[f"all_{m}"].append(round(aggr[m].mean().item(), 3))
        scores[f"rejection"].append(round(aggr["rejection"].float().mean().item(), 3))

        acc_clean = eval_score(preds_clean_graphs_repeated, labels_ori_repeated, self.config, self.loader["id_val"].dataset.minority_class)
        acc_interven = eval_score(preds_perturbed_graphs, labels_ori_repeated, self.config, self.loader["id_val"].dataset.minority_class)
        acc_ints.append(acc_interven.item())

        if log_info:
            print()
            print(f"Label distrib: {labels_ori.unique(return_counts=True)}")
            print(f"Acc clean", round(acc_clean.item(), 3))
            print(f"Acc interven", round(acc_interven.item(), 3))
            print(f"len(reference) = {len(reference)}")
            for m in ["TV", "predicted"]:
                for c in labels_ori.long().unique().numpy().tolist():
                    print(f"{metric.upper()} for class {c}_{m} = {scores[str(c)+'_'+m][-1]} +- {aggr['predicted'].std():.3f} (in-sample avg dev_std = {(aggr['std_predicted']**2).mean().sqrt():.3f})")
                print(f"{metric.upper()} all_classes {m} = {scores[f'all_{m}'][-1]} +- {aggr[f'{m}'].std():.3f} (in-sample avg dev_std =", torch.round((aggr[f"std_{m}"]**2).mean().sqrt(), decimals=3).item())
            print(f"{metric.upper()} rejection = {scores[f'rejection'][-1]}")
        return scores, acc_ints


    def normalize_belonging(self, belonging):
        """Relabel consecutive runs of equal values as 0, 1, 2, ...

        Every maximal run of identical adjacent entries receives the same id, and the id
        grows by one each time the value changes. Non-adjacent repeats of a value therefore
        get distinct ids (e.g. ``[7, 7, 3, 7] -> [0, 0, 1, 2]``).

        Args:
            belonging: indexable sequence (list or 1-D tensor) of group ids.

        Returns:
            list[int]: the run index of each element, same length as ``belonging``.
        """
        run_ids = []
        current_run = -1
        for pos, value in enumerate(belonging):
            # A new run starts at the first element or whenever the value differs from its predecessor.
            if pos == 0 or not (value == belonging[pos - 1]):
                current_run += 1
            run_ids.append(current_run)
        return run_ids

    def get_aggregated_metric(self, metric, preds_clean, preds_perturb, belonging):
        """Aggregate per-perturbation prediction changes into per-sample scores.

        Args:
            metric (str): faithfulness metric name. "suff_cause" aggregates the divergences of a
                sample's perturbations with a max; every other metric uses the mean.
            preds_clean: clean-graph predictions, row-aligned with ``preds_perturb``. Shape (N, C).
                With C == 1 the single column is thresholded at 0.5, so it is presumably a
                positive-class probability.
            preds_perturb: predictions on the perturbed graphs, shape (N, C).
            belonging: per-row id of the originating clean sample. Consecutive runs of equal ids
                are mapped to groups 0, 1, 2, ... by ``normalize_belonging``.

        Returns:
            dict of per-group tensors:
                "TV": aggregated sum over classes of |clean - perturbed|,
                "predicted": aggregated |clean - perturbed| on the clean predicted class,
                "rejection": 1 if the predicted class flipped for at least one perturbation.
                    Perturbed predictions in the 0.40-0.60 band whose clean counterpart was
                    outside it also count as flips, for one- and two-column outputs.
                "std_TV" / "std_predicted": per-group std of the two divergences.
        """
        ret = {}
        # Build the group index on the predictions' device: torch_scatter requires src and index
        # to live on the same device.
        belonging = torch.tensor(self.normalize_belonging(belonging), device=preds_clean.device)

        div_TV = torch.abs(preds_clean - preds_perturb).sum(-1)
        if preds_clean.shape[1] == 1:
            div_predicted = torch.abs(preds_clean - preds_perturb)
        else:
            pred_class = preds_clean.argmax(-1).unsqueeze(1)
            div_predicted = torch.abs(
                preds_clean.gather(1, pred_class) - preds_perturb.gather(1, pred_class)
            )

        # Aggregate across the expval_budget perturbations of each sample.
        if metric in ["suff_cause"]:
            ret["TV"], _ = scatter_max(div_TV, belonging, dim=0)
            ret["predicted"], _ = scatter_max(div_predicted, belonging, dim=0)
        else:
            ret["TV"] = scatter_mean(div_TV, belonging, dim=0)
            ret["predicted"] = scatter_mean(div_predicted, belonging, dim=0)

        # div_rejection: whether the model prediction changed (1 when it did).
        if preds_clean.shape[1] == 1:
            pred_class_clean = (preds_clean > 0.5)
            pred_class_pert = (preds_perturb > 0.5)

            uncertain_clean = (preds_clean >= 0.40) & (preds_clean <= 0.60)
            uncertain_pert = (preds_perturb >= 0.40) & (preds_perturb <= 0.60)
            uncertain_to_penalize = uncertain_pert & (~uncertain_clean)
            # Force uncertain perturbed predictions for certain clean predictions to be treated as changed.
            pred_class_pert[uncertain_to_penalize] = ~pred_class_clean[uncertain_to_penalize]
            div_rejection = (pred_class_clean != pred_class_pert).long()
        else:
            pred_class_clean = preds_clean.argmax(-1).unsqueeze(1)
            pred_class_pert = preds_perturb.argmax(-1).unsqueeze(1)

            if preds_clean.shape[1] == 2:  # still binary classification, but with two output columns (e.g. DIR)
                conf_clean = preds_clean.gather(1, pred_class_clean)
                conf_pert = preds_perturb.gather(1, pred_class_pert)
                uncertain_clean = (conf_clean >= 0.40) & (conf_clean <= 0.60)
                uncertain_pert = (conf_pert >= 0.40) & (conf_pert <= 0.60)
                uncertain_to_penalize = uncertain_pert & (~uncertain_clean)
                # Force uncertain perturbed predictions for certain clean predictions to be treated
                # as changed: assign the other class of the pair. These are integer class indices,
                # so flip with 1 - c rather than bitwise ~ (which would give -1 / -2).
                pred_class_pert[uncertain_to_penalize] = 1 - pred_class_clean[uncertain_to_penalize]
            div_rejection = (pred_class_clean != pred_class_pert).long()

        # Rejection is the worst case over a sample's perturbations: 1 if the prediction changed at least once.
        ret["rejection"], _ = scatter_max(div_rejection, belonging, dim=0)

        ret["std_TV"] = scatter_std(div_TV, belonging, dim=0)
        ret["std_predicted"] = scatter_std(div_predicted, belonging, dim=0)
        return ret

    def get_local_dataset(self, split, log=True):
        """Return the dataset behind ``self.loader[split]``, rebalanced for HIV when requested.

        When ``log`` is set and torch_geometric is exactly version 2.4.0, print the dataset, its
        first sample and its label distribution. Print an imbalance warning when the smallest
        and largest class counts differ by more than 1000.

        For datasets whose name contains "hiv" (unless ``config.numsamples_budget`` is "all"),
        randomly undersample to balanced classes with a fixed ``random_state=42``.

        Args:
            split (str): key into ``self.loader``.
            log (bool): enable the version-gated diagnostic prints.

        Returns:
            The (possibly undersampled) dataset.
        """
        if torch_geometric.__version__ == "2.4.0" and log:
            source = self.loader[split].dataset
            print(source, "for split ", split)
            print(f"Data example from {split}: {source.get(0)}")
            print(f"Label distribution from {split}: {source.y.unique(return_counts=True)}")

        dataset = self.loader[split].dataset

        _, class_counts = dataset.y.unique(return_counts=True)
        if abs(class_counts.min() - class_counts.max()) > 1000:
            print(f"#D#Unbalanced warning for {self.config.dataset.dataset_name} ({split})")

        wants_balancing = (
            "hiv" in self.config.dataset.dataset_name.lower()
            and str(self.config.numsamples_budget) != "all"
        )
        if wants_balancing:
            sampler = RandomUnderSampler(random_state=42)
            kept_idx, _ = sampler.fit_resample(np.arange(len(dataset)).reshape(-1, 1), dataset.y)
            dataset = dataset[kept_idx.reshape(-1)]
            print(f"Creating balanced dataset: {dataset.y.unique(return_counts=True)}")
        return dataset

    @torch.no_grad()
    def evaluate(self, split: str, epoch:int, compute_plaus=False, compute_mcc=False, compute_clf_only_pred=False):
        r"""
        Collect predictions, losses and optional explanation metrics over one dataset split.
        (For project use only)

        Args:
            split (str): A split string for choosing the corresponding dataloader. Allowed: 'train', 'id_val', 'id_test',
                'val', and 'test'.
            epoch (int): current epoch, forwarded to ``ood_algorithm.loss_postprocess``.
            compute_plaus (bool): also compute the ROC-AUC of the node explanation against the dataset's node-level
                ground truth. Only MNIST (``node_label``) and BAColorGVIsolated (``node_is_spurious``) provide one;
                graphs of other datasets, or with single-class ground truth, are skipped.
            compute_mcc (bool): also compute precision/recall/F1 of the thresholded attention against
                ``self.get_pretrain_targets`` on nodes of correctly-classified graphs. Mutually exclusive with
                ``compute_plaus``.
            compute_clf_only_pred (bool): also collect ``self.model.predict_from_subgraph`` outputs obtained with
                all-ones edge and node attention.

        Returns:
            dict with 'score', 'loss', 'loss_dict' (per-term mean losses), likelihood statistics, explanation
            metrics (NaN when not computed), 'pred' and 'pred_clf_only'. If the split has no loader, or a batch
            yields no mask, the partial dict {'score': None, 'loss': None, 'wiou': None} is returned instead.

        """
        assert not (compute_plaus and compute_mcc)

        stat = {'score': None, 'loss': None, 'wiou': None}
        if self.loader.get(split) is None:
            return stat

        was_training = self.model.training
        self.model.eval()
        # The finally clause restores training mode on every exit path: normal completion, the
        # early return when a batch yields no mask, and any exception raised mid-evaluation.
        try:
            loss_names = ("spec_loss", "entr_loss", "l_norm_loss", "clf_loss", "total_loss")
            loss_per_batch_dict = defaultdict(list)
            mask_all = []
            pred_all = []
            pred_clf_only_all = []
            target_all = []
            likelihoods_all = []
            wious_all, aucroc_all, f1_pos_all, f1_neg_all = [], [], [], []
            recall_pos_all, prec_pos_all, recall_neg_all, prec_neg_all = [], [], [], []
            pbar = tqdm(self.loader[split], desc=f'Eval {split.capitalize()}', total=len(self.loader[split]),
                        **pbar_setting)
            for data in pbar:
                data: Batch = data.to(self.config.device)

                mask, targets = nan2zero_get_mask(data, split, self.config)
                if mask is None:
                    return stat

                node_norm = data.get('node_norm') if self.config.model.model_level == 'node' else None

                data, targets, mask, node_norm = self.ood_algorithm.input_preprocess(
                    data,
                    targets,
                    mask,
                    node_norm,
                    self.model.training,
                    self.config
                )

                model_output = self.model(
                    data=data,
                    edge_weight=None,
                    ood_algorithm=self.ood_algorithm
                )

                if compute_clf_only_pred:
                    clf_only_output = self.model.predict_from_subgraph(
                        data=data,
                        edge_att=torch.ones((data.edge_index.shape[1]), device=data.x.device),
                        node_att=torch.ones((data.x.shape[0], 1), device=data.x.device)
                    ).squeeze(-1)
                    pred_clf_only_all.append(clf_only_output.cpu().numpy())

                # --------------- Loss collection ------------------
                raw_preds = self.ood_algorithm.output_postprocess(model_output)
                loss = self.ood_algorithm.loss_calculate(raw_preds, targets, mask, node_norm, self.config, batch=data.batch)
                # Called for its effect on the per-term loss attributes read just below.
                loss = self.ood_algorithm.loss_postprocess(loss, data, mask, self.config, epoch)

                mask_all.append(mask)
                for l in loss_names:
                    loss_per_batch_dict[l].append(float(getattr(self.ood_algorithm, l, np.nan)))

                # ------------- Likelihood data collection ------------------
                if raw_preds.shape[-1] > 1:
                    probs = raw_preds.softmax(dim=1)
                    likelihoods_all.append(probs.gather(1, targets.unsqueeze(1)))
                else:
                    # Single-output models: likelihood is not tracked, -1 acts as a placeholder.
                    probs = raw_preds.sigmoid()
                    likelihoods_all.append(torch.full_like(probs, fill_value=-1))

                # ------------- Score data collection ------------------
                pred, target = eval_data_preprocess(data.y, raw_preds, mask, self.config)
                pred_all.append(pred)
                target_all.append(target)

                # ------------- PLAUSIBILITY ------------------
                if compute_plaus:
                    dataset_name = self.config.dataset.dataset_name
                    for j, g in enumerate(data.to_data_list()):
                        if dataset_name == "MNIST":
                            gt = g.node_label.cpu().numpy()
                        elif dataset_name == "BAColorGVIsolated":
                            gt = g.node_is_spurious.cpu().numpy()
                        else:
                            # No node-level ground truth for this dataset: nothing to score.
                            continue
                        if np.all(gt == 1) or np.all(gt == 0):
                            # sk_roc_auc breaks as only 1 class is present
                            continue

                        node_expl = self.ood_algorithm.edge_att[data.batch == j].detach().cpu().squeeze(-1).numpy()
                        aucroc_all.append(
                            sk_roc_auc(gt, node_expl, average="macro")
                        )

                # ------------- PLAUSIBILITY WRT INDUCED-DEGENERATE EXPLANATIONS ------------------
                if compute_mcc:
                    if raw_preds.shape[-1] > 1:
                        preds = raw_preds.argmax(dim=1)
                    else:
                        preds = (raw_preds > 0.5).long().view(-1)

                    # extract for each node the label of the graph it belongs to
                    if len(data.y.shape) > 1:
                        graph_label_per_node = data.y.view(-1)[data.batch]
                    else:
                        graph_label_per_node = data.y[data.batch]

                    correct_nodes = graph_label_per_node == preds[data.batch]

                    node_att = self.ood_algorithm.att.sigmoid()[correct_nodes]  # take att_log_logit of each model
                    targets = self.get_pretrain_targets(data)[correct_nodes]

                    prec_pos, recall_pos, f1_pos, _ = precision_recall_fscore_support(
                        (targets > 0).cpu().numpy(),
                        (node_att.squeeze(1) > 0.9).cpu().numpy(),
                        average='binary',
                        pos_label=1,
                    )
                    prec_neg, recall_neg, f1_neg, _ = precision_recall_fscore_support(
                        (targets > 0).cpu().numpy(),
                        (node_att.squeeze(1) > 0.1).cpu().numpy(),
                        average='binary',
                        pos_label=0,
                    )
                    f1_pos_all.append(f1_pos)
                    f1_neg_all.append(f1_neg)
                    prec_neg_all.append(prec_neg)
                    recall_neg_all.append(recall_neg)
                    prec_pos_all.append(prec_pos)
                    recall_pos_all.append(recall_pos)

            # ------- Loss calculate -------
            mask_all = torch.cat(mask_all)
            likelihoods_all = torch.cat(likelihoods_all)

            for l in loss_names:
                loss_per_batch_dict[l] = np.mean(loss_per_batch_dict[l])

            stat['likelihood_avg'] = likelihoods_all.mean()
            stat['likelihood_prod'] = torch.prod(likelihoods_all)
            stat['likelihood_logprod'] = torch.sum(likelihoods_all.log())
            stat['wiou'] = np.mean(wious_all) if len(wious_all) > 0 else np.nan
            stat['aucroc'] = np.mean(aucroc_all) if len(aucroc_all) > 0 else np.nan
            stat['f1_pos'] = np.mean(f1_pos_all) if len(f1_pos_all) > 0 else np.nan
            stat['f1_neg'] = np.mean(f1_neg_all) if len(f1_neg_all) > 0 else np.nan
            stat['prec_neg'] = np.mean(prec_neg_all) if len(prec_neg_all) > 0 else np.nan
            stat['recall_neg'] = np.mean(recall_neg_all) if len(recall_neg_all) > 0 else np.nan
            stat['prec_pos'] = np.mean(prec_pos_all) if len(prec_pos_all) > 0 else np.nan
            stat['recall_pos'] = np.mean(recall_pos_all) if len(recall_pos_all) > 0 else np.nan

            # --------------- Metric calculation including ROC_AUC, Accuracy, AP.  --------------------
            stat['score'] = eval_score(pred_all, target_all, self.config, self.loader[split].dataset.minority_class)

            print(
                f'{split.capitalize()} {self.config.metric.score_name}: {stat["score"]:.4f} \t' + 
                f'{split.capitalize()} Loss: {loss_per_batch_dict["total_loss"]:.4f} \t' + 
                (f'{split.capitalize()} WIoU: {stat["wiou"]:.3f} \t' if compute_plaus else '') +
                (f'{split.capitalize()} AUCROC: {stat["aucroc"]:.3f} \t' if compute_plaus else '') +
                (f'{split.capitalize()} F1_pos: {stat["f1_pos"]:.3f} \t' if compute_mcc else '') +
                (f'{split.capitalize()} F1_neg: {stat["f1_neg"]:.3f} \t' if compute_mcc else '')
            )

            return {
                'score': stat['score'],
                'loss': loss_per_batch_dict['total_loss'],
                'loss_dict': loss_per_batch_dict,
                'likelihood_avg': stat['likelihood_avg'],
                'likelihood_prod': stat['likelihood_prod'],
                'likelihood_logprod': stat['likelihood_logprod'],
                'wiou': stat['wiou'],
                'aucroc': stat['aucroc'],
                'f1_pos': stat['f1_pos'],
                'f1_neg': stat['f1_neg'],
                'prec_neg': stat['prec_neg'],
                'recall_neg': stat['recall_neg'],
                'prec_pos': stat['prec_pos'],
                'recall_pos': stat['recall_pos'],
                'pred': pred_all,
                'pred_clf_only': pred_clf_only_all
            }
        finally:
            if was_training:
                self.model.train()

    def load_task(self, load_param=False, load_split="ood"):
        r"""
        Launch a training or a test, depending on ``self.task``.

        Args:
            load_param: forwarded to ``config_model`` when testing.
            load_split: forwarded to ``config_model`` when testing.

        Returns:
            ``(None, None)`` after training; the ``(test_score, ckpt)`` pair returned by
            ``config_model('test', ...)`` when testing; None for any other task value.
        """
        task = self.task
        if task == 'test':
            # config model
            print('#D#Config model and output the best checkpoint info...')
            score, checkpoint = self.config_model('test', load_param=load_param, load_split=load_split)
            return score, checkpoint
        if task == 'train':
            self.train()
            return None, None
        return None

    def config_model(self, mode: str, load_param=False, load_split="ood"):
        r"""
        A model configuration utility. Responsible for transiting model from CPU -> GPU and loading checkpoints.

        Args:
            mode (str): 'train' or 'test'.
            load_param: When True, loading test checkpoint will load parameters to the GNN model.
            load_split (str): checkpoint to take parameters from when ``load_param`` is True:
                "ood" (``config.test_ckpt``) or "id" (``config.id_test_ckpt``). Ignored for EERM, which
                always loads the OOD checkpoint into ``self.model.gnn``.

        Returns:
            ``(test_score, id_ckpt)`` if mode == 'test'. ``test_score`` is the "test_score" entry of the
            OOD checkpoint. ``id_ckpt`` is the loaded in-domain checkpoint dict, or None when
            ``config.id_test_ckpt`` does not exist. None otherwise.

        Raises:
            ValueError: if parameters are loaded with an unsupported ``load_split``, or with
                ``load_split="id"`` while no in-domain checkpoint exists.
        """
        self.model.to(self.config.device)
        self.model.train()

        # load checkpoint
        if mode == 'train' and self.config.train.tr_ctn:
            # Resuming training is deliberately disabled: the statements below only run when
            # assertions are stripped (python -O).
            assert False
            ckpt = torch.load(os.path.join(self.config.ckpt_dir, f'last.ckpt'))
            self.model.load_state_dict(ckpt['state_dict'])
            best_ckpt = torch.load(os.path.join(self.config.ckpt_dir, f'best.ckpt'))
            self.config.metric.best_stat['score'] = best_ckpt['val_score']
            self.config.metric.best_stat['loss'] = best_ckpt['val_loss']
            self.config.train.ctn_epoch = ckpt['epoch'] + 1
            print(f'#IN#Continue training from Epoch {ckpt["epoch"]}...')

        if mode == 'test':
            try:
                ckpt = torch.load(self.config.test_ckpt, map_location=self.config.device)
            except FileNotFoundError:
                print(f'#E#Checkpoint not found at {os.path.abspath(self.config.test_ckpt)}')
                exit(1)

            # Stays None when there is no in-domain checkpoint, so the final return is always well-defined.
            id_ckpt = None
            if os.path.exists(self.config.id_test_ckpt):
                id_ckpt = torch.load(self.config.id_test_ckpt, map_location=self.config.device)
                print(f'#IN#Loading best In-Domain Checkpoint {id_ckpt["epoch"]} in {self.config.id_test_ckpt}')
                print(f'#IN#Checkpoint {id_ckpt["epoch"]}: \n-----------------------------------\n'
                      f'Train {self.config.metric.score_name}: {id_ckpt["train_score"]:.4f}\n'
                      f'Train Loss: {id_ckpt.get("train_loss", np.nan):.4f}\n'
                      f'Spec Loss: {id_ckpt.get("spec_loss", np.nan):.4f}\n'
                      f'Mean Loss: {id_ckpt.get("mean_loss", np.nan):.4f}\n'
                      f'Total Loss: {id_ckpt.get("total_loss", np.nan):.4f}\n'
                      f'ID Validation {self.config.metric.score_name}: {id_ckpt["id_val_score"]:.4f}\n'
                      f'ID Validation Loss: {id_ckpt["id_val_loss"].item():.4f}\n'
                      f'ID Test {self.config.metric.score_name}: {id_ckpt["id_test_score"]:.4f}\n'
                      f'ID Test Loss: {id_ckpt["id_test_loss"].item():.4f}\n'
                      f'OOD Validation {self.config.metric.score_name}: {id_ckpt["val_score"]:.4f}\n'
                      f'OOD Validation Loss: {id_ckpt["val_loss"].item():.4f}\n'
                      f'OOD Test {self.config.metric.score_name}: {id_ckpt["test_score"]:.4f}\n'
                      f'OOD Test Loss: {id_ckpt["test_loss"].item():.4f}\n')
                print(f'#IN#Loading best Out-of-Domain Checkpoint {ckpt["epoch"]}...')
                print(f'#IN#Checkpoint {ckpt["epoch"]}: \n-----------------------------------\n'
                      f'Train {self.config.metric.score_name}: {ckpt["train_score"]:.4f}\n'
                      f'Train Loss: {ckpt["train_loss"].item():.4f}\n'
                      f'ID Validation {self.config.metric.score_name}: {ckpt["id_val_score"]:.4f}\n'
                      f'ID Validation Loss: {ckpt["id_val_loss"].item():.4f}\n'
                      f'ID Test {self.config.metric.score_name}: {ckpt["id_test_score"]:.4f}\n'
                      f'ID Test Loss: {ckpt["id_test_loss"].item():.4f}\n'
                      f'OOD Validation {self.config.metric.score_name}: {ckpt["val_score"]:.4f}\n'
                      f'OOD Validation Loss: {ckpt["val_loss"].item():.4f}\n'
                      f'OOD Test {self.config.metric.score_name}: {ckpt["test_score"]:.4f}\n'
                      f'OOD Test Loss: {ckpt["test_loss"].item():.4f}\n')

                print(f'#IN#ChartInfo {id_ckpt["id_test_score"]:.4f} {id_ckpt["test_score"]:.4f} '
                      f'{ckpt["id_test_score"]:.4f} {ckpt["test_score"]:.4f} {ckpt["id_val_score"]:.4f} {ckpt["val_score"]:.4f}', end='')
            else:
                print(f'#IN#No In-Domain checkpoint.')
                print(f'#IN#Loading best Checkpoint {ckpt["epoch"]}...')
                print(f'#IN#Checkpoint {ckpt["epoch"]}: \n-----------------------------------\n'
                      f'Train {self.config.metric.score_name}: {ckpt["train_score"]:.4f}\n'
                      f'Train Loss: {ckpt["train_loss"].item():.4f}\n'
                      f'Validation {self.config.metric.score_name}: {ckpt["val_score"]:.4f}\n'
                      f'Validation Loss: {ckpt["val_loss"].item():.4f}\n'
                      f'Test {self.config.metric.score_name}: {ckpt["test_score"]:.4f}\n'
                      f'Test Loss: {ckpt["test_loss"].item():.4f}\n')

                print(
                    f'#IN#ChartInfo {ckpt["test_score"]:.4f} {ckpt["val_score"]:.4f}', end='')
            if load_param:
                if self.config.ood.ood_alg != 'EERM':
                    if load_split == "ood":
                        self.model.load_state_dict(ckpt['state_dict'])
                    elif load_split == "id":
                        if id_ckpt is None:
                            raise ValueError(
                                f"load_split='id' requested but no in-domain checkpoint exists at "
                                f"{self.config.id_test_ckpt}"
                            )
                        self.model.load_state_dict(id_ckpt['state_dict'])
                    else:
                        raise ValueError(f"{load_split} not supported")
                else:
                    self.model.gnn.load_state_dict(ckpt['state_dict'])
            return ckpt["test_score"], id_ckpt


    def save_epoch(self, epoch: int, train_stat: dict, id_val_stat: dict, id_test_stat: dict, val_stat: dict,
                   test_stat: dict, config: Union[CommonArgs, Munch], loss_per_batch_dict: dict, manual_save:str=None):
        r"""
        Training util for checkpoint saving.

        Nothing is saved before ``config.train.pre_train`` epochs. Afterwards the epoch is written to
        ``{config.ckpt_dir}/{epoch}.ckpt`` and copied to ``last.ckpt`` only when at least one holds:
        it improves the OOD validation reference metric, it improves the in-domain validation
        reference metric, or ``epoch`` is a multiple of ``config.train.save_gap``.

        Improvements also refresh ``best.ckpt`` / ``id_best.ckpt`` and the matching
        ``config.metric.best_stat`` / ``config.metric.id_best_stat`` entries. With
        ``config.clean_save`` the per-epoch file is removed after copying.

        Args:
            epoch (int): epoch number
            train_stat (dict): train statistics
            id_val_stat (dict): in-domain validation statistics
            id_test_stat (dict): in-domain test statistics
            val_stat (dict): ood validation statistics
            test_stat (dict): ood test statistics
            config (Union[CommonArgs, Munch]): munchified dictionary of args (:obj:`config.ckpt_dir`, :obj:`config.dataset`, :obj:`config.train`, :obj:`config.model`, :obj:`config.metric`, :obj:`config.log_path`, :obj:`config.ood`)
            loss_per_batch_dict (dict): extra entries (e.g. per-term losses) merged into the checkpoint.
            manual_save (str, optional): when given, the saved checkpoint is also copied to ``{manual_save}.ckpt``.

        Returns:
            None

        """
        if epoch < config.train.pre_train:
            return

        # WARNING: Original reference metric is 'score'
        reference_metric = "loss"
        lower_better = 1 if reference_metric == "loss" else -1

        # Evaluate both improvement tests once; neither the stats nor the best values change before they are used.
        best_ood = config.metric.best_stat[reference_metric]
        is_new_ood_best = best_ood is None or \
            lower_better * val_stat[reference_metric] < lower_better * best_ood
        # A falsy in-domain reference value (missing, None, 0) disables the in-domain checkpoint.
        is_new_id_best = bool(id_val_stat.get(reference_metric)) and (
            config.metric.id_best_stat[reference_metric] is None
            or lower_better * id_val_stat[reference_metric]
            < lower_better * config.metric.id_best_stat[reference_metric]
        )

        if not (is_new_ood_best or is_new_id_best or epoch % config.train.save_gap == 0):
            return

        # Only build the (potentially large) checkpoint once we know it will be written.
        state_dict = self.model.state_dict() if config.ood.ood_alg != 'EERM' else self.model.gnn.state_dict()
        ckpt = {
            'state_dict': state_dict,
            'train_score': train_stat['score'],
            'train_loss': train_stat['loss'],
            'id_val_score': id_val_stat['score'],
            'id_val_loss': id_val_stat['loss'],
            'id_test_score': id_test_stat['score'],
            'id_test_loss': id_test_stat['loss'],
            'val_score': val_stat['score'],
            'val_loss': val_stat['loss'],
            'test_score': test_stat['score'],
            'test_loss': test_stat['loss'],
            'time': datetime.datetime.now().strftime('%b%d %Hh %M:%S'),
            'model': {
                'model name': f'{config.model.model_name} {config.model.model_level} layers',
                'dim_hidden': config.model.dim_hidden,
                'dim_ffn': config.model.dim_ffn,
                'global pooling': config.model.global_pool
            },
            'dataset': config.dataset.dataset_name,
            'train': {
                'weight_decay': config.train.weight_decay,
                'learning_rate': config.train.lr,
                'mile stone': config.train.mile_stones,
                'shift_type': config.dataset.shift_type,
                'Batch size': f'{config.train.train_bs}, {config.train.val_bs}, {config.train.test_bs}'
            },
            'OOD': {
                'OOD alg': config.ood.ood_alg,
                'OOD param': config.ood.ood_param,
                'number of environments': config.dataset.num_envs
            },
            'log file': config.log_path,
            'epoch': epoch,
            'max epoch': config.train.max_epoch
        }
        ckpt.update(loss_per_batch_dict)

        if not os.path.exists(config.ckpt_dir):
            os.makedirs(config.ckpt_dir)
            print(f'#W#Directory does not exists. Have built it automatically.\n'
                  f'{os.path.abspath(config.ckpt_dir)}')

        saved_file = os.path.join(config.ckpt_dir, f'{epoch}.ckpt')
        torch.save(ckpt, saved_file)
        shutil.copy(saved_file, os.path.join(config.ckpt_dir, f'last.ckpt'))

        if manual_save is not None:
            print(f'#W#Saving manual checkpoint {manual_save}.ckpt')
            shutil.copy(saved_file, os.path.join(config.ckpt_dir, f'{manual_save}.ckpt'))

        # --- In-Domain checkpoint ---
        if is_new_id_best:
            config.metric.id_best_stat['score'] = id_val_stat['score']
            config.metric.id_best_stat['loss'] = id_val_stat['loss']
            shutil.copy(saved_file, os.path.join(config.ckpt_dir, f'id_best.ckpt'))
            print('#IM#Saved a new best In-Domain checkpoint.')

        # --- Out-Of-Domain checkpoint ---
        if is_new_ood_best:
            config.metric.best_stat['score'] = val_stat['score']
            config.metric.best_stat['loss'] = val_stat['loss']
            shutil.copy(saved_file, os.path.join(config.ckpt_dir, f'best.ckpt'))
            print('#IM#Saved a new best checkpoint.')

        if config.clean_save:
            os.unlink(saved_file)

    def get_node_explanations(self, num_samples=None):
        """Compute node-level explanation scores on the "id_val" split.

        Puts the model in eval mode and does not restore the previous mode.

        Args:
            num_samples (int, optional): if truthy, only the first ``num_samples`` graphs of the split are used.

        Returns:
            dict: ``{"id_val": {"scores": [...], "samples": [...], "pred": [...]}}``. For each graph it holds:
            its node scores as a Python list, the graph moved to CPU, and the matching row of the ``logits``
            returned by ``self.model.get_subgraph``. For DIR models, nodes that ``split_graph_node`` does not
            keep get score -1.0.
        """
        self.model.eval()

        splits = ["id_val"]
        ret = {name: {"scores": [], "samples": [], "pred": []} for name in splits}

        for split in splits:
            dataset = self.get_local_dataset(split)
            if num_samples:
                dataset = dataset[:num_samples]
            bucket = ret[split]

            loader = DataLoader(dataset, batch_size=256, shuffle=False, num_workers=2)
            for batch_data in loader:
                batch_data: Batch = batch_data.to(self.config.device)

                _edge_scores, node_scores, logits = self.model.get_subgraph(
                    data=batch_data,
                    edge_weight=None,
                    ood_algorithm=self.ood_algorithm,
                    do_relabel=False
                )

                for graph_idx, graph in enumerate(batch_data.to_data_list()):
                    # Boolean indexing copies, so the in-place DIR masking below leaves node_scores untouched.
                    graph_scores = node_scores[batch_data.batch == graph_idx].detach()

                    if self.config.model.model_name == "DIR":
                        # Mark nodes outside DIR's TopK selection with -1.0.
                        _, _, (kept, removed) = split_graph_node(
                            graph, graph_scores, self.config.ood.ood_param, embed=None, use_input_feat=True
                        )
                        assert kept.shape[0] + removed.shape[0] == graph.x.shape[0]
                        graph_scores[removed] = -1.0

                    bucket["scores"].append(graph_scores.cpu().numpy().squeeze(1).tolist())
                    bucket["samples"].append(graph.cpu())
                    bucket["pred"].append(logits[graph_idx].cpu())
        return ret

    def generate_panel(self):
        """Return the model's per-graph edge scores on the "train" split.

        Puts the model in eval mode and does not restore the previous mode.

        Returns:
            list[list]: for every graph of ``self.get_local_dataset("train")``, in loader order,
            the scores of its edges as a Python list.

        NOTE(review): only the "train" split is processed and nothing is plotted. The earlier
        panel-plotting code returned from inside its split loop, so this was all it ever
        produced. Its figure was never drawn, saved or closed.
        """
        self.model.eval()

        dataset = self.get_local_dataset("train")
        loader = DataLoader(dataset, batch_size=512, shuffle=False, num_workers=2)
        edge_scores = []
        for data in loader:
            data: Batch = data.to(self.config.device)
            subgraph_out = self.model.get_subgraph(
                data=data,
                edge_weight=None,
                ood_algorithm=self.ood_algorithm,
                do_relabel=False
            )
            # get_node_explanations unpacks this same call as (edge_scores, node_scores, logits);
            # keep only the edge scores, while still accepting a bare tensor.
            edge_score = subgraph_out[0] if isinstance(subgraph_out, (tuple, list)) else subgraph_out
            # Graph id of each edge, taken from its source node.
            edge_graph_id = data.batch[data.edge_index[0]]
            for j in range(data.num_graphs):
                edge_scores.append(edge_score[edge_graph_id == j].detach().cpu().numpy().tolist())
        return edge_scores

    

    
    
