import copy
import json
import os
import random

from tqdm import tqdm

import numpy as np
import torch
from torch.autograd import Variable
from torch.nn import Linear
import torch.nn.functional as F

from torch_geometric.datasets import TUDataset, BA2MotifDataset, BAMultiShapesDataset
from torch_geometric.nn import global_add_pool
from torch_geometric.nn import Sequential, GAT, GCN, GIN
from torch_geometric.utils import to_dense_adj, subgraph
from torch_geometric.data import Data

import networkx as nx

from scipy.stats import spearmanr
import shap


# Short metric codes mapped to human-readable metric names.
# Some names embed a newline (presumably so they wrap when used as plot
# labels -- confirm against the plotting code).
# NOTE(review): "I7" and "M3" both map to the same name, "Compactness".
METRICS = {
    "A": "Accuracy",
    "A1": "Accuracy (instance-level)",
    "A2": "Accuracy (model-level)",
    "I1": "Completeness\n(with)",
    "I2": "Completeness\n(without)",
    "I3": "Consistency",
    "I4": "Continuity (nodes)",
    "I5": "Continuity (edges)",
    "I6": "Contrastivity",
    "I7": "Compactness",
    "M1": "Correctness (nodes)",
    "M2": "Correctness (edges)",
    "M3": "Compactness",
}


def find_five_node_cycles(graph_data, return_all=False):
    """Look for simple cycles of exactly five nodes in a graph.

    The graph is built (undirected) from ``graph_data.edge_index`` and then
    converted to a directed graph before cycle enumeration.

    Args:
        graph_data: Object exposing a ``[2, num_edges]`` ``edge_index`` tensor.
        return_all: When true, collect every five-node cycle that is
            enumerated; otherwise stop at the first one.

    Returns:
        A list of cycles (each a list of node ids) when ``return_all`` is true;
        otherwise the first cycle found, or ``None`` when there is none.
    """
    undirected = nx.Graph()
    undirected.add_edges_from(graph_data.edge_index.t().tolist())

    # Lazy filter, so the early-exit case stops enumerating at the first hit.
    five_cycles = (
        candidate
        for candidate in nx.simple_cycles(nx.DiGraph(undirected))
        if len(candidate) == 5
    )
    if return_all:
        return list(five_cycles)
    return next(five_cycles, None)


def find_house_motif(graph_data):
    """Return the nodes of a five-node cycle that contains a chord, if any.

    A five-node cycle qualifies as soon as one of its nodes has exactly three
    neighbours that also belong to the cycle.

    Args:
        graph_data: Object exposing a ``[2, num_edges]`` ``edge_index`` tensor.

    Returns:
        The first qualifying cycle as a list of node ids, or ``[]``.
    """
    graph = nx.Graph()
    graph.add_edges_from(graph_data.edge_index.t().tolist())

    for candidate in find_five_node_cycles(graph_data, return_all=True):
        for member in candidate:
            inside_degree = sum(1 for nb in graph.neighbors(member) if nb in candidate)
            if inside_degree == 3:
                return candidate
    return []


def find_wheel_motif(graph_data):
    """Return a five-node cycle plus a hub node connected to exactly that cycle.

    Args:
        graph_data: Object exposing a ``[2, num_edges]`` ``edge_index`` tensor.

    Returns:
        ``cycle + [hub]`` (six node ids) for the first match, or ``[]``.
    """
    graph = nx.Graph()
    graph.add_edges_from(graph_data.edge_index.t().tolist())

    for rim in find_five_node_cycles(graph_data, return_all=True):
        rim_sorted = sorted(rim)
        for hub in graph.nodes:
            if hub in rim:
                continue
            # The hub's neighbourhood must be the rim and nothing else.
            if sorted(graph.neighbors(hub)) == rim_sorted:
                return rim + [hub]
    return []


def is_3x3_grid(subgraph):
    """Check whether a graph contains all twelve edges of a 3x3 lattice.

    Nodes are numbered 0..8 in the order ``subgraph.nodes()`` yields them, and
    that numbering is interpreted row-major as lattice positions.

    Args:
        subgraph: A NetworkX graph (expected to hold nine nodes).

    Returns:
        True when every lattice edge is present (extra edges are allowed).
    """
    horizontal = {(3 * row + col, 3 * row + col + 1) for row in range(3) for col in range(2)}
    vertical = {(3 * row + col, 3 * (row + 1) + col) for row in range(2) for col in range(3)}
    required = horizontal | vertical

    position = {node: index for index, node in enumerate(subgraph.nodes())}
    present = {tuple(sorted((position[u], position[v]))) for u, v in subgraph.edges()}
    return required <= present


def find_grid_motif(graph_data):
    """Find a 3x3 grid: an eight-node cycle plus one further node.

    Every eight-node cycle is combined with each node outside it, and the
    induced subgraph is tested with ``is_3x3_grid``.

    Args:
        graph_data: Object exposing a ``[2, num_edges]`` ``edge_index`` tensor.

    Returns:
        ``cycle + [node]`` (nine node ids) for the first match, or ``[]``.
    """
    graph = nx.Graph()
    graph.add_edges_from(graph_data.edge_index.t().tolist())

    eight_cycles = (ring for ring in nx.simple_cycles(nx.DiGraph(graph)) if len(ring) == 8)
    for ring in eight_cycles:
        for extra in graph.nodes:
            if extra in ring:
                continue
            candidate = ring + [extra]
            if is_3x3_grid(graph.subgraph(candidate)):
                return candidate
    return []


def get_mutag_gt(data):
    """Attach a heuristic ground-truth explanation to a MUTAG graph.

    The explanation is the union of (a) every enumerated cycle whose nodes all
    carry one-hot label 0 (called ``C`` in this code) and (b), for graphs with
    ``data.y == 0``, every label-1 node (``N``) together with its label-2
    neighbours (``O``). If a ``y == 0`` graph has no such N/O group, the
    explanation is emptied.

    Args:
        data: Graph object with ``x`` (one-hot over 7 classes), ``edge_index``
            and ``y``. It is modified in place.

    Returns:
        The same ``data`` object with ``explanation_nodes_subset`` (long tensor
        of node ids) and ``explanation_graph`` (induced ``edge_index``) set, or
        both set to ``None`` when the explanation is empty.

    Raises:
        AssertionError: If an N node has label-2 neighbours but not exactly two.
    """
    G = nx.Graph()
    G.add_edges_from(data.edge_index.t().tolist())
    explanation_nodes_subset = list()

    def find_atom(x, idx, label):
        """Return the entries of ``idx`` whose feature row equals one-hot ``label``."""
        idx = torch.tensor(idx, dtype=torch.long)
        one_hot = F.one_hot(torch.tensor([label], dtype=torch.long), num_classes=7)
        mask = (x[idx] == one_hot).all(dim=-1)
        return idx[mask].tolist()

    # NOTE(review): enumerating simple cycles on the bidirected graph also
    # yields a 2-node "cycle" u -> v -> u for every edge, so any label-0/label-0
    # edge is accepted here, not only ring atoms. Confirm this is intended.
    for cycle in nx.simple_cycles(nx.DiGraph(G)):
        nodes_C = find_atom(data.x, cycle, 0)
        if len(nodes_C) == len(cycle):
            explanation_nodes_subset += cycle

    if data.y == 0:
        no2_group = None
        nodes_N = find_atom(data.x, list(range(len(data.x))), 1)
        for n in nodes_N:
            neighbors = list(G.neighbors(n))
            nodes_O = find_atom(data.x, neighbors, 2)
            if len(nodes_O) > 0:
                assert len(nodes_O) == 2
                no2_group = nodes_O + [n]
                explanation_nodes_subset += no2_group
        if no2_group is None:
            # Class-0 graphs without an N/O group get no explanation at all.
            explanation_nodes_subset = []

    explanation_nodes_subset = torch.tensor(list(set(explanation_nodes_subset)), dtype=torch.long)
    if len(explanation_nodes_subset) > 0:
        edge_index, _ = subgraph(explanation_nodes_subset, data.edge_index)
        data.explanation_nodes_subset = explanation_nodes_subset
        data.explanation_graph = edge_index
    else:
        data.explanation_nodes_subset = None
        data.explanation_graph = None
    return data


def get_gt_explanation_ba_2motif(data):
    """Attach the BA-2motif ground-truth explanation to ``data`` in place.

    The template graph whose label matches ``data.y`` becomes
    ``data.explanation_graph``; the explanation nodes are always ids 20..24.

    Args:
        data: Graph object with a single-element label ``y`` (0 or 1).

    Returns:
        The same ``data`` object with both explanation attributes set.

    Raises:
        AssertionError: If ``data.y`` is neither 0 nor 1.
    """
    templates = get_gt_explanations_model("BA-2motif")
    cycle_template = [graph for graph in templates if graph.y.item() == 0][0]
    house_template = [graph for graph in templates if graph.y.item() == 1][0]

    if data.y == 0:
        data.explanation_graph = cycle_template
    elif data.y == 1:
        data.explanation_graph = house_template
    else:
        assert False, data.y

    motif_nodes = [20, 21, 22, 23, 24]
    data.explanation_nodes_subset = torch.tensor(motif_nodes).long()
    return data


def _merge_gt_graphs(templates, sizes):
    """Build the disjoint union of template graphs, shifting node ids by the running size."""
    shifted_edges, offset = [], 0
    for template, size in zip(templates, sizes):
        shifted_edges.append(template.edge_index + offset)
        offset += size
    return Data(
        edge_index=torch.cat(shifted_edges, dim=1),
        y=templates[0].y.clone(),
        num_nodes=offset,
    )


def get_gt_explanation_ba_multishapes(data):
    """Attach the BAMultiShapes ground-truth explanation to ``data`` in place.

    The house, wheel and grid detectors are run on ``data``. The explanation
    node set is the concatenation (in that order) of the motifs found. The
    explanation graph is a clone of the matching template when one motif is
    found, and the disjoint union of the matching templates when several are.

    Previously the multi-motif branches stored a 1-D tensor of shifted node ids
    of ``data`` as ``explanation_graph``, whereas the single-motif branches
    stored a template graph. The same offsets (5, 5 + 6, ...) are now applied
    to the template edge indices, so the attribute is always a graph.
    NOTE(review): confirm downstream distance functions expect a ``Data`` here.

    Args:
        data: Graph object exposing ``edge_index``.

    Returns:
        The same ``data`` object with ``explanation_nodes_subset`` and
        ``explanation_graph`` set (both ``None`` when no motif is found).

    Raises:
        AssertionError: If a detector returns an unexpected number of nodes.
    """
    gt_explanations = get_gt_explanations_model("BAMultiShapes")
    gt_house, gt_wheel, gt_grid = gt_explanations[0], gt_explanations[1], gt_explanations[2]

    house = torch.tensor(find_house_motif(data)).long()
    wheel = torch.tensor(find_wheel_motif(data)).long()
    grid = torch.tensor(find_grid_motif(data)).long()

    # (detected nodes, template graph, expected motif size)
    motifs = [(house, gt_house, 5), (wheel, gt_wheel, 6), (grid, gt_grid, 9)]
    for nodes, _, size in motifs:
        if len(nodes) not in (0, size):
            # Explicit raise keeps the original exception type, even under -O.
            raise AssertionError((house, wheel, grid))

    present = [motif for motif in motifs if len(motif[0]) > 0]
    if not present:
        data.explanation_nodes_subset = None
        data.explanation_graph = None
    elif len(present) == 1:
        nodes, template, _ = present[0]
        data.explanation_nodes_subset = nodes
        data.explanation_graph = template.clone()
    else:
        data.explanation_nodes_subset = torch.cat([nodes for nodes, _, _ in present], dim=0)
        data.explanation_graph = _merge_gt_graphs(
            [template for _, template, _ in present],
            [size for _, _, size in present],
        )
    return data


def get_dataset(dataset_name, degree_attr=False, use_ones=False, use_node_attr=False):
    """Load a graph-classification dataset by name.

    ``"BA-2motif"`` and ``"BAMultiShapes"`` use their dedicated dataset
    classes (stored under ``datasets/saved/<name>``, always force-reloaded).
    Any other name is loaded as a ``TUDataset`` rooted at ``datasets/saved``.
    For ``IMDB-BINARY``, ``IMDB-MULTI`` and ``REDDIT-BINARY`` a transform
    replaces ``data.x`` with synthetic node features.

    Args:
        dataset_name: Name of the dataset.
        degree_attr: Synthetic features are the row sums of the dense adjacency.
        use_ones: Synthetic features are a single constant 1 (only consulted
            when ``degree_attr`` is false).
        use_node_attr: Forwarded to ``TUDataset``.

    Returns:
        The dataset object.
    """

    def transform_fn(data):
        # Precedence: degree features, then all-ones, then a constant 0.1 x 10.
        if degree_attr:
            data.x = to_dense_adj(data.edge_index)[0].sum(-1).unsqueeze(-1)
        elif use_ones:
            data.x = torch.ones(data.num_nodes, 1)
        else:
            data.x = 0.1 * torch.ones(data.num_nodes, 10)
        return data

    synthetic_loaders = {
        "BA-2motif": BA2MotifDataset,
        "BAMultiShapes": BAMultiShapesDataset,
    }
    if dataset_name in synthetic_loaders:
        loader = synthetic_loaders[dataset_name]
        return loader(f"datasets/saved/{dataset_name}", force_reload=True)

    needs_synthetic_features = dataset_name in ["IMDB-BINARY", "IMDB-MULTI", "REDDIT-BINARY"]
    return TUDataset(
        "datasets/saved",
        dataset_name,
        use_node_attr=use_node_attr,
        transform=transform_fn if needs_synthetic_features else None,
    )


def get_splits(dataset_name, size=None, seed=123, split=-1):
    """Return train/validation/test indices for a dataset.

    Args:
        dataset_name: Used to locate ``datasets/splits/<dataset_name>.json``
            when a stored split is requested.
        size: Number of samples; required for the random split.
        seed: Seed for the random permutation (also seeds NumPy's global RNG).
        split: Negative for a random 80/10/10 split; otherwise the index of the
            stored split to read from the JSON file.

    Returns:
        ``(train_idxs, val_idxs, test_idxs)``: tensors for the random split,
        whatever the JSON holds for a stored split.
    """
    if split >= 0:
        with open(os.path.join("datasets/splits", f"{dataset_name}.json")) as f:
            stored = json.load(f)
        fold = stored[split]
        test_part = fold["test"]
        inner = fold["model_selection"][0]
        return inner["train"], inner["validation"], test_part

    rng = torch.Generator()
    rng.manual_seed(seed)
    np.random.seed(seed)
    order = torch.randperm(size, generator=rng)
    train_end, val_end = int(size * 0.8), int(size * 0.9)
    return order[:train_end], order[train_end:val_end], order[val_end:]


def get_model(model_type, num_node_features, num_classes, hidden_dim=32, num_layers=3, linear_dim=32):
    """Build a graph classifier: GNN backbone -> sum pooling -> linear head.

    Args:
        model_type: ``"GCN"``, ``"GAT"`` or ``"GIN"``.
        num_node_features: Input feature dimension.
        num_classes: Number of output logits.
        hidden_dim: Hidden width of the backbone.
        num_layers: Number of message-passing layers.
        linear_dim: Output width of the backbone and input width of the head.
            It is now honoured for every backbone; GAT and GIN previously
            hard-coded 32 (the default), so default calls are unchanged.

    Returns:
        A ``torch_geometric.nn.Sequential`` taking ``(x, edge_index, batch)``.

    Raises:
        AssertionError: If ``model_type`` is not one of the supported names.
    """
    # Backbone class and the keyword arguments specific to it.
    backbones = {
        "GCN": (GCN, {"dropout": 0.4}),
        "GAT": (GAT, {"act": "elu", "dropout": 0.6}),
        "GIN": (GIN, {"norm": "batch_norm", "dropout": 0.4}),
    }
    if model_type not in backbones:
        # Explicit raise keeps the original exception type and survives -O.
        raise AssertionError(f"unknown model_type: {model_type!r}")

    backbone_cls, extra_kwargs = backbones[model_type]
    # Backbone first, then head: same parameter-initialisation order as before.
    backbone = backbone_cls(
        num_node_features,
        hidden_dim,
        num_layers=num_layers,
        out_channels=linear_dim,
        **extra_kwargs,
    )
    return Sequential(
        "x, edge_index, batch",
        [
            (backbone, "x, edge_index -> x"),
            (global_add_pool, "x, batch -> x"),
            Linear(linear_dim, num_classes),
        ],
    )


def train(model, optimizer, data):
    """Run one optimisation step on a single batch.

    Args:
        model: Module called as ``model(x, edge_index, batch=batch)``.
        optimizer: Optimiser over ``model.parameters()``.
        data: Batch exposing ``x``, ``edge_index``, ``batch`` and ``y``.

    Returns:
        The cross-entropy loss of the batch as a Python float.
    """
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index, batch=data.batch)
    loss = F.cross_entropy(out, data.y)
    loss.backward()
    # Clip after backward(): clipping before it acted on the freshly
    # zeroed gradients and therefore had no effect on the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 2.0)
    optimizer.step()
    return float(loss)


@torch.no_grad()
def test(model, data):
    """Evaluate the model on one batch without tracking gradients.

    Args:
        model: Module called as ``model(x, edge_index, batch=batch)``; it is
            switched to eval mode.
        data: Batch exposing ``x``, ``edge_index``, ``batch`` and ``y``.

    Returns:
        ``(loss, acc)``: cross-entropy loss and mean accuracy, both tensors.
    """
    model.eval()
    logits = model(data.x, data.edge_index, batch=data.batch)
    hits = logits.argmax(dim=-1) == data.y
    return F.cross_entropy(logits, data.y), hits.float().mean()


def train_model(
    model, dataloader_train, dataloader_val, dataloader_test, epochs=200, lr=0.001, weight_decay=0.005
):
    """Train with Adam and restore the weights with the best validation accuracy.

    Every tenth epoch the model is evaluated on all three loaders (per-batch
    losses and accuracies are averaged without weighting by batch size).
    Whenever validation accuracy strictly improves, the weights are
    snapshotted. The model itself is not moved to a device here; only the
    batches are.

    Args:
        model: Model to train; modified in place.
        dataloader_train: Iterable of training batches.
        dataloader_val: Iterable of validation batches.
        dataloader_test: Iterable of test batches.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        The same model, loaded with the best snapshot and set to eval mode.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    def _mean_metrics(loader):
        # Unweighted mean of per-batch loss and accuracy.
        losses, accuracies = [], []
        for batch in loader:
            batch_loss, batch_acc = test(model, batch.to(device))
            losses.append(batch_loss.item())
            accuracies.append(batch_acc.item())
        return np.mean(losses), np.mean(accuracies)

    pbar = tqdm(range(1, epochs + 1))
    max_acc_train, max_acc_val, max_acc_test = 0, 0, 0
    best_state_dict = copy.deepcopy(model.state_dict())

    for epoch in pbar:
        model.train()
        for batch in dataloader_train:
            train(model, optimizer, batch.to(device))

        if epoch % 10 != 0:
            continue

        loss_train, acc_train = _mean_metrics(dataloader_train)
        loss_val, acc_val = _mean_metrics(dataloader_val)
        loss_test, acc_test = _mean_metrics(dataloader_test)

        if acc_val > max_acc_val:
            max_acc_train, max_acc_val, max_acc_test = acc_train, acc_val, acc_test
            best_state_dict = copy.deepcopy(model.state_dict())

        pbar.set_description(
            f"Train l:{loss_train:.4f} a:{acc_train:.4f} | Val l:{loss_val:.4f} a:{acc_val:.4f} | Test l:{loss_test:.4f} a:{acc_test:.4f}"
        )

    print(f"Train a:{max_acc_train:.4f} | Val a:{max_acc_val:.4f} | Test a:{max_acc_test:.4f}")
    pbar.close()
    model.load_state_dict(best_state_dict)
    model.eval()

    return model


def get_acc(model, dataloader, forward_fn):
    """Compute classification accuracy over a dataloader.

    Args:
        model: Passed through to ``forward_fn``.
        dataloader: Iterable of batches exposing ``to(device)`` and ``y``.
        forward_fn: Callable ``(model, batch) -> logits``.

    Returns:
        Fraction of samples whose arg-max prediction equals ``y``; ``0.0``
        when no sample was seen.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    correct, total = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            predicted = forward_fn(model, batch).argmax(dim=-1)
            correct += (predicted == batch.y).float().sum()
            total += len(batch.y)
    if total > 0:
        return correct.item() / total
    return 0.0


def get_gt_explanations_model(dataset_name):
    """Return the hand-written model-level ground-truth explanation graphs.

    Args:
        dataset_name: ``"MUTAG"``, ``"REDDIT-BINARY"``, ``"BA-2motif"`` or
            ``"BAMultiShapes"``.

    Returns:
        A list of ``Data`` graphs for a known dataset, ``[]`` otherwise. Each
        graph stores its edges once (one direction) and ``num_nodes`` as
        ``edge_index.max() + 1``. Only the MUTAG graphs carry node features
        (one-hot over 7 classes).
    """
    gt_dataset = list()
    cls_0, cls_1 = torch.tensor([0], dtype=torch.long), torch.tensor([1], dtype=torch.long)

    if dataset_name == "MUTAG":
        # Class 0: one label-1 node bonded to two label-2 nodes.
        edge_index = torch.tensor([[0, 0], [1, 2]], dtype=torch.long)
        x = F.one_hot(torch.tensor([1, 2, 2]).long(), num_classes=7)
        gt_dataset.append(Data(edge_index=edge_index, x=x, y=cls_0.clone(), num_nodes=edge_index.max() + 1))
        # Class 1: six-node ring of label-0 nodes.
        edge_index = torch.tensor([[0, 1, 2, 3, 4, 0], [1, 2, 3, 4, 5, 5]], dtype=torch.long)
        x = F.one_hot(torch.tensor([0, 0, 0, 0, 0, 0]).long(), num_classes=7)
        gt_dataset.append(Data(edge_index=edge_index, x=x, y=cls_1.clone(), num_nodes=edge_index.max() + 1))

    elif dataset_name == "REDDIT-BINARY":
        # Class 0: star with one centre and three leaves.
        edge_index = torch.tensor([[0, 0, 0], [1, 2, 3]]).long()
        gt_dataset.append(Data(edge_index=edge_index, y=cls_0.clone(), num_nodes=edge_index.max() + 1))
        # Class 1: two centres (0, 1) each linked to nodes 2..5.
        # `torch.long` is a dtype object; the previous `torch.long()` raised TypeError.
        edge_index = torch.tensor([[0, 0, 0, 0, 1, 1, 1, 1], [2, 3, 4, 5, 2, 3, 4, 5]], dtype=torch.long)
        gt_dataset.append(Data(edge_index=edge_index, y=cls_1.clone(), num_nodes=edge_index.max() + 1))

    elif dataset_name == "BA-2motif":
        # Class 0: five-node cycle.
        edge_index = torch.tensor([[0, 1, 2, 3, 0], [1, 2, 3, 4, 4]], dtype=torch.long)
        gt_dataset.append(Data(edge_index=edge_index, y=cls_0.clone(), num_nodes=edge_index.max() + 1))
        # Class 1: five-node cycle plus the chord (1, 4).
        edge_index = torch.tensor([[0, 1, 2, 3, 0, 1], [1, 2, 3, 4, 4, 4]], dtype=torch.long)
        gt_dataset.append(Data(edge_index=edge_index, y=cls_1.clone(), num_nodes=edge_index.max() + 1))

    elif dataset_name == "BAMultiShapes":
        house = torch.tensor([[0, 1, 2, 3, 0, 1], [1, 2, 3, 4, 4, 4]], dtype=torch.long)
        wheel = torch.tensor(
            [[0, 1, 2, 3, 0, 1, 0, 1, 2, 3, 4], [1, 2, 3, 4, 4, 4, 5, 5, 5, 5, 5]], dtype=torch.long
        )
        grid = torch.tensor(
            [[0, 0, 1, 1, 2, 3, 3, 4, 4, 5, 6, 7], [1, 2, 2, 4, 5, 4, 6, 5, 7, 8, 7, 8]], dtype=torch.long
        )
        # Order matters to callers: house, wheel, grid. All three are labelled class 0.
        gt_dataset.append(Data(edge_index=house.clone(), y=cls_0.clone(), num_nodes=house.amax() + 1))
        gt_dataset.append(Data(edge_index=wheel.clone(), y=cls_0.clone(), num_nodes=wheel.amax() + 1))
        gt_dataset.append(Data(edge_index=grid.clone(), y=cls_0.clone(), num_nodes=grid.amax() + 1))
    else:
        return []
    return gt_dataset


def compare_lists(init_data, after_data, distance_fn):
    """Apply ``distance_fn`` to aligned pairs from two sequences.

    Args:
        init_data: First sequence; its items are the first argument.
        after_data: Second sequence; its items are the second argument.
        distance_fn: Callable ``(init_item, after_item) -> distance``.

    Returns:
        List of distances, as long as the shorter input.
    """
    return list(map(distance_fn, init_data, after_data))


def find_elbow(values, softmax=False):
    """Pick a threshold at the "elbow" of the sorted value curve.

    The values are sorted ascending and plotted against their rank. The elbow
    is the point farthest from the straight line joining the first and last
    points of that curve.

    Args:
        values: 1-D tensor of scores.
        softmax: When true, the curve is built from the softmax of the sorted
            values; the returned threshold is still an original value.

    Returns:
        The element of the sorted input at the elbow position (0-dim tensor).
    """
    torch_y = values.sort()[0]
    # Explicit dim: the implicit-dim form of softmax is deprecated and warns.
    y = F.softmax(torch_y, dim=0) if softmax else torch_y
    y = y.cpu().detach().numpy()
    x = np.arange(len(y))

    line_vec = np.array([x[-1] - x[0], y[-1] - y[0]])
    line_norm = np.linalg.norm(line_vec)
    if line_norm == 0:
        # Single-element input: the only value is the threshold. This avoids
        # the 0/0 that previously produced NaNs before returning the same value.
        return torch_y[0]
    line_vec = line_vec / line_norm

    # Perpendicular distance of every curve point to the first-to-last line.
    vec_from_first = np.array([x - x[0], y - y[0]]).T
    proj_onto_line = np.dot(vec_from_first, line_vec)
    distance_to_line = np.linalg.norm(vec_from_first - proj_onto_line[:, None] * line_vec, axis=1)
    elbow_idx = np.argmax(distance_to_line)
    return torch_y[elbow_idx]


def find_percentile(values, p=0.85, softmax=True):
    """Return the value at which the cumulative mass of sorted scores reaches ``p``.

    Scores are sorted ascending and turned into weights (softmax, or
    division by their sum). The result is the largest score whose cumulative
    weight is still ``<= p``.

    Args:
        values: 1-D tensor of scores.
        p: Cumulative-mass cut-off.
        softmax: Use softmax weights; otherwise the scores must be
            non-negative and are normalised by their sum.

    Returns:
        A 0-dim tensor taken from ``values``. When even the smallest score
        exceeds the cut-off, ``values[0]`` is returned (original fallback).

    Raises:
        AssertionError: If ``softmax`` is false and a score is negative.
    """
    y, indices = values.sort()
    if softmax:
        weights = F.softmax(y, dim=0)
    else:
        assert torch.all(y >= 0)
        # Tensor reduction instead of Python's sum(), which loops element-wise.
        weights = y / y.sum()
    cumsum = torch.cumsum(weights, dim=0)

    # Original positions of all scores that fit under the cut-off.
    below_cutoff = indices[cumsum <= p]
    cutoff_index = below_cutoff[-1] if len(below_cutoff) > 0 else 0
    return values[cutoff_index]


def pairwise_list_corr(x, to_abs=True):
    """Average Spearman correlation over all pairs of columns of ``x``.

    Args:
        x: 2-D array-like indexed as ``x[:, i]``.
        to_abs: Average absolute correlations instead of signed ones.

    Returns:
        Mean pairwise correlation; undefined (NaN) correlations count as 1.0.
    """
    n_columns = x.shape[-1]
    scores = []
    for left in range(n_columns):
        for right in range(left + 1, n_columns):
            rho, _ = spearmanr(x[:, left], x[:, right])
            if to_abs:
                rho = np.abs(rho)
            scores.append(1.0 if np.isnan(rho) else rho)
    return np.mean(scores)


def get_shap(classifier, inputs):
    """Fit a SHAP ``DeepExplainer`` on ``inputs`` and explain those same inputs.

    Globally re-enables gradient tracking as a side effect.

    Args:
        classifier: Model handed to ``shap.DeepExplainer``.
        inputs: Tensor used both as background data and as the batch to explain.

    Returns:
        ``(shap_values, explainer)``, where ``shap_values`` is the explainer
        output stacked along a new leading axis, as a tensor.
    """
    torch.set_grad_enabled(True)
    explainer = shap.DeepExplainer(classifier, Variable(inputs))
    per_output = explainer.shap_values(Variable(inputs))
    stacked = torch.from_numpy(np.stack(per_output, axis=0))
    return stacked, explainer


def run_shap(explainer, input):
    """Compute SHAP values for a batch, falling back to sample-by-sample.

    The whole batch is tried first. If that raises, each sample is explained
    on its own, and a sample that still fails gets an all-zero attribution.
    Only ``Exception`` subclasses trigger the fallback, so ``KeyboardInterrupt``
    and ``SystemExit`` now propagate instead of being swallowed.

    Args:
        explainer: Object with ``shap_values(batch)`` and, for the zero
            fallback, ``expected_value`` (``shape[0]`` gives the output count).
        input: Tensor whose first axis indexes samples.

    Returns:
        Tensor of stacked SHAP values. The fallback path stacks per-sample
        results along dim 1, so each sample keeps its own size-1 batch axis.
    """

    def _explain(batch):
        return torch.from_numpy(np.stack(explainer.shap_values(Variable(batch))))

    try:
        return _explain(input)
    except Exception:
        # Best-effort: retry one sample at a time below.
        pass

    per_sample = []
    for i in range(input.shape[0]):
        try:
            per_sample.append(_explain(input[i : i + 1]))
        except Exception:
            per_sample.append(torch.zeros(explainer.expected_value.shape[0], 1, input.shape[-1]))
    return torch.stack(per_sample, dim=1)


def get_edge_mask(edge_index, nidx, isin=True):
    """Mark the edges whose two endpoints both belong to ``nidx``.

    Args:
        edge_index: ``[2, num_edges]`` tensor of node ids.
        nidx: Node ids to keep (anything ``torch.tensor`` accepts).
        isin: Use vectorised ``torch.isin``; otherwise test membership
            element by element.

    Returns:
        Boolean tensor of length ``num_edges``.
    """
    wanted = torch.tensor(nidx, dtype=edge_index.dtype, device=edge_index.device)
    sources, targets = edge_index[0], edge_index[1]

    if isin:
        source_ok = torch.isin(sources, wanted)
        target_ok = torch.isin(targets, wanted)
    else:

        def _membership(endpoints):
            flags = [node in wanted for node in endpoints]
            return torch.tensor(flags, dtype=torch.bool, device=edge_index.device)

        source_ok = _membership(sources)
        target_ok = _membership(targets)

    return (source_ok & target_ok).bool()


def remove_edges(edge_index, edges_to_remove):
    """Remove edges from an edge list, ignoring edge direction.

    Both inputs are reduced to sets of sorted endpoint pairs, so ``(u, v)`` and
    ``(v, u)`` are the same edge and duplicates collapse.

    Args:
        edge_index: ``[2, num_edges]`` tensor.
        edges_to_remove: ``[2, k]`` tensor of edges to drop.

    Returns:
        ``[2, m]`` tensor holding each remaining undirected edge once (built
        from Python lists, so it lives on the CPU); ``[2, 0]`` long tensor when
        nothing remains.
    """

    def _undirected_pairs(columns):
        return {tuple(sorted(pair)) for pair in columns.t().tolist()}

    remaining = _undirected_pairs(edge_index) - _undirected_pairs(edges_to_remove)
    if not remaining:
        return torch.zeros([2, 0]).long()
    return torch.tensor(list(remaining)).t().contiguous()


def add_noise_perturb(dataset, p_x=0.05, p_edges_del=0.01, p_edges_add=0.01, skip_last=False, explanations=None):
    """Create a randomly perturbed copy of every graph in ``dataset``.

    Three independent perturbations are applied, each using NumPy's global RNG:
    node features are resampled, existing edges are deleted, and new edges are
    added. When an explanation is supplied for a graph, its nodes and the edges
    between them are protected from feature resampling and edge deletion.

    Args:
        dataset: Iterable of graphs exposing ``x``, ``edge_index`` and ``y``.
        p_x: Per-node probability of replacing the feature row.
        p_edges_del: Per-edge-column probability of deleting an edge.
        p_edges_add: Per-candidate-pair probability of adding an edge.
        skip_last: Exclude the last node of every graph from the feature pool
            and from feature resampling.
        explanations: Optional sequence aligned with ``dataset``; each entry is
            ``None`` or an object with a ``nodes_subset`` attribute
            (presumably a tensor of node ids -- confirm with callers).

    Returns:
        A list of new ``Data`` objects (tensors on the CPU), one per input graph.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pool of feature rows that replacement features are drawn from.
    if skip_last:
        all_x = torch.cat([data.x[:-1] for data in dataset])
    else:
        all_x = torch.cat([data.x for data in dataset])

    def _add_noise_perturb(data):
        """Perturb one ``(graph, explanation)`` pair."""
        data, e = data
        if e is not None:
            e = e.nodes_subset
        x, edge_index = data.x.clone().to(device), data.edge_index.clone().to(device)
        n, m = len(data.x), data.edge_index.shape[-1]
        if p_x > 0:
            # Bernoulli(p_x) draw per node; selected rows are replaced below.
            keys = torch.tensor(np.random.choice([0, 1], p=[1 - p_x, p_x], size=len(x))).bool()
            if skip_last:
                keys[-1] = False
            if e is not None:
                # Never overwrite the features of explanation nodes.
                keys[e] = False
            size = keys.sum().item()
            # NOTE(review): old_x is never read afterwards.
            old_x = x.clone()
            if size > 0:
                x[keys] = all_x[np.random.randint(low=0, high=len(all_x), size=size)].to(device)
        if p_edges_del > 0:
            # Bernoulli(p_edges_del) draw per edge column.
            idx = torch.tensor(np.random.choice([0, 1], p=[1 - p_edges_del, p_edges_del], size=m)).bool()
            if e is not None:
                e = e.to(device)
                # Edges with both endpoints inside the explanation are kept.
                keep = [((edge_index[0, i] in e) and (edge_index[1, i] in e)) for i in range(edge_index.shape[1])]
                keep = torch.tensor(keep).to(device)
                # NOTE(review): idx is created on the CPU while keep is moved to
                # `device`; this may mismatch on a CUDA machine -- verify.
                idx[keep] = False
            if idx.sum() > 0:
                # remove_edges ignores direction and returns each remaining
                # undirected edge once.
                edge_index = remove_edges(edge_index, edge_index[:, idx.bool()])
        if p_edges_add > 0:
            # All node pairs i < j, minus the pairs that are already connected.
            new_edges = torch.tensor([[i, j] for i in range(0, len(x)) for j in range(i + 1, len(x))]).T.to(
                device
            )
            new_edges = remove_edges(new_edges, edge_index).to(device)
            idx = (
                torch.tensor(
                    np.random.choice([0, 1], p=[1 - p_edges_add, p_edges_add], size=new_edges.shape[-1])
                )
                .bool()
                .to(device)
            )
            if idx.sum() > 0:
                new_edges = new_edges[:, idx]
                edge_index = edge_index.to(device)
                # Existing and new edges are each appended together with their reverse.
                # NOTE(review): if no deletion happened above, edge_index may already
                # hold both directions, in which case this duplicates every edge; and
                # if a deletion happened but nothing is added here, edge_index stays
                # single-direction. Confirm the intended edge convention.
                edge_index = torch.cat([edge_index, edge_index[[1, 0]], new_edges, new_edges[[1, 0]]], dim=-1)
        return Data(x=x.cpu(), edge_index=edge_index.cpu(), y=data.y.cpu())

    if explanations is None:
        explanations = [None] * len(dataset)
    return list(map(_add_noise_perturb, zip(dataset, explanations)))


def find_topk(values, k):
    """Return the k-th largest entry of ``values``.

    Args:
        values: 1-D tensor of scores.
        k: Rank to look up; clamped to ``len(values)``.

    Returns:
        0-dim tensor holding the smallest of the top-``k`` values.
    """
    effective_k = min(k, len(values))
    top_values, _ = torch.topk(values, k=effective_k)
    return top_values[-1]


def extract_subset_and_edge_index(
    edge_index, num_nodes, node_mask=None, node_mask_fn=None, edge_mask=None, edge_mask_fn=None
):
    """Threshold node and edge importance masks into a node subset and edge list.

    Each ``*_mask_fn`` maps a mask to a threshold, or to ``None`` meaning "do
    not threshold". Entries ``>=`` the threshold are kept.

    The node subset is chosen as follows: the nodes passing the node threshold
    if one was applied; otherwise the endpoints of the kept edges if an edge
    threshold was applied; otherwise all ``num_nodes`` nodes.

    ``node_mask_fn`` is now called once. It used to be called a second time on
    the already binarised mask, only to test whether it returned ``None``.

    Args:
        edge_index: ``[2, num_edges]`` tensor.
        num_nodes: Number of nodes the node mask refers to.
        node_mask: Optional per-node scores.
        node_mask_fn: Threshold function for ``node_mask`` (required if it is given).
        edge_mask: Optional per-edge scores.
        edge_mask_fn: Threshold function for ``edge_mask`` (required if it is given).

    Returns:
        ``(subset, edge_index)``: kept node ids and the kept edge columns.
    """
    device = edge_index.device
    subset = torch.arange(num_nodes, device=device)

    node_threshold_applied = False
    if node_mask is not None:
        node_thr = node_mask_fn(node_mask)
        if node_thr is not None:
            subset = subset[node_mask >= node_thr]
            node_threshold_applied = True

    if edge_mask is not None:
        edge_thr = edge_mask_fn(edge_mask)
        if edge_thr is not None:
            edge_index = edge_index[:, edge_mask >= edge_thr]
            if not node_threshold_applied:
                # No node-level selection: derive the nodes from the kept edges.
                subset = torch.unique(edge_index)
    return subset, edge_index


def print_results(results):
    """Print a mapping as two comma-separated lines: keys, then values.

    Args:
        results: Mapping from metric name to a number or ``None``. Numbers are
            printed with four decimals, ``None`` as ``nan``.
    """
    header_cells = [f"{key}" for key in results]
    print(",".join(header_cells))

    value_cells = ("nan" if value is None else f"{value:.4f}" for value in results.values())
    print(",".join(value_cells))


def get_fn(str_fn):
    """Turn a ``"name:param"`` spec into a threshold function.

    A spec without a colon always selects the ``"none"`` function, whatever
    the text is, so even parameter-free methods need a colon (e.g.
    ``"elbow:0"``).

    Args:
        str_fn: Spec string; ``name`` is one of ``elbow``, ``elbow_softmax``,
            ``percentile``, ``topk`` or ``none``.

    Returns:
        A one-argument callable mapping a mask to a threshold (or to ``None``).

    Raises:
        KeyError: If ``name`` is unknown.
        ValueError: If the spec contains more than one colon.
    """
    if ":" in str_fn:
        fn, param = str_fn.split(":")
    else:
        fn, param = "none", 0

    # param is converted lazily, inside the selected function.
    threshold_fns = {
        "elbow": lambda x: find_elbow(x, softmax=False),
        "percentile": lambda x: find_percentile(x, float(param), softmax=True),
        "topk": lambda x: find_topk(x, int(param)),
        "none": lambda x: None,
        "elbow_softmax": lambda x: find_elbow(x, softmax=True),
    }
    return threshold_fns[fn]


def get_responses_and_logits(model, dataloader, forward_fn):
    """Run ``forward_fn`` over a dataloader and concatenate its two outputs.

    Args:
        model: Passed through to ``forward_fn``.
        dataloader: Iterable of batches exposing ``to(device)``.
        forward_fn: Callable ``(model, batch) -> (logits, responses)``.

    Returns:
        ``(responses, logits)``, each concatenated along dim 0. Note the order
        is swapped relative to what ``forward_fn`` returns.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    all_logits, all_responses = [], []
    with torch.no_grad():
        for batch in dataloader:
            batch_logits, batch_responses = forward_fn(model, batch.to(device))
            all_responses.append(batch_responses)
            all_logits.append(batch_logits)
    return torch.cat(all_responses, 0), torch.cat(all_logits)


def remove_explanation_perturb(explanations, node_mask_fn, edge_mask_fn, skip_last_node):
    """Split each explained graph into its explanation part and the remainder.

    Args:
        explanations: Iterable of explanation objects exposing ``x``,
            ``edge_index``, ``node_mask``, ``edge_mask``, ``pred`` and ``to``.
        node_mask_fn: Threshold function for the summed node mask.
        edge_mask_fn: Threshold function for the summed edge mask.
        skip_last_node: Leave the last node out of both parts.

    Returns:
        A ``zip`` iterator yielding two tuples: the "with explanation" graphs
        and the "without explanation" graphs. An entry is ``None`` when that
        part contains no node. Both parts are labelled with the arg-max of
        ``pred``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _split(e):
        e = e.to(device)
        total_nodes = len(e.x)

        node_scores = e.node_mask.sum(1)
        if skip_last_node:
            node_scores = node_scores[:-1]
        kept_nodes, kept_edges = extract_subset_and_edge_index(
            e.edge_index,
            total_nodes - int(skip_last_node),
            node_scores,
            node_mask_fn,
            e.edge_mask.sum(1),
            edge_mask_fn,
        )
        kept_edges, _ = subgraph(kept_nodes, kept_edges, relabel_nodes=True, num_nodes=total_nodes)

        y_pred = e.pred.argmax().unsqueeze(0)
        with_explanation = None
        if len(kept_nodes) > 0:
            with_explanation = Data(edge_index=kept_edges, x=e.x[kept_nodes], y=y_pred)

        # Complement of the explanation nodes (minus the last node if skipped).
        is_rest = torch.ones(total_nodes, dtype=torch.bool, device=device)
        is_rest[kept_nodes] = False
        if skip_last_node:
            is_rest[-1] = False
        rest_nodes = torch.arange(total_nodes, device=device)[is_rest]

        without_explanation = None
        if len(rest_nodes) != 0:
            rest_edges, _ = subgraph(rest_nodes, e.edge_index, relabel_nodes=True, num_nodes=total_nodes)
            without_explanation = Data(edge_index=rest_edges, x=e.x[rest_nodes], y=y_pred)
        return with_explanation, without_explanation

    pairs = [_split(explanation) for explanation in explanations]
    return zip(*pairs)


def closest_prototype_acc(model, cls, dataloader, forward_fn):
    """Accuracy of labelling each sample with the class of its best-scoring prototype.

    Args:
        model: Passed through to ``forward_fn``.
        cls: Tensor giving the class of each prototype.
        dataloader: Iterable of batches exposing ``to(device)`` and ``y``.
        forward_fn: Callable ``(model, batch) -> scores`` with prototypes on
            the last axis; the arg-max selects the prototype.

    Returns:
        Fraction of samples whose selected prototype class equals ``y``;
        ``0.0`` when no sample was seen.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    prototype_classes = cls.to(device)
    hits, seen = 0, 0
    with torch.no_grad():
        for batch in dataloader:
            batch = batch.to(device)
            nearest = forward_fn(model, batch).argmax(-1)
            hits += (prototype_classes[nearest] == batch.y).float().sum()
            seen += len(batch.y)
    if seen > 0:
        return hits.item() / seen
    return 0.0


def compare_with_gt_model(explanations, ground_truths, distance_fn):
    """For every explanation, return its distance to the closest ground truth.

    Args:
        explanations: Iterable of explanations.
        ground_truths: Sequence of ground-truth objects (must be non-empty).
        distance_fn: Callable ``(ground_truth, explanation) -> distance``.

    Returns:
        List with one minimum distance per explanation.
    """
    n_truths = len(ground_truths)
    return [
        min(compare_lists(ground_truths, [explanation] * n_truths, distance_fn))
        for explanation in explanations
    ]


def compare_with_gt_instance(dataset_name, dataset, truncated_explanations, distance_fn):
    """Compare per-graph explanations with instance-level ground truth.

    The ground truth is attached to each graph by a dataset-specific builder
    (an identity function for datasets without one). Graphs that end up with
    no ``explanation_nodes_subset`` are skipped. After the comparison, the two
    explanation attributes on the graph are reset to ``None``.

    Args:
        dataset_name: Selects the ground-truth builder.
        dataset: Iterable of graphs.
        truncated_explanations: Iterable of predicted explanations, aligned
            with ``dataset``.
        distance_fn: Callable ``(ground_truth_graph, prediction) -> distance``.

    Returns:
        List of distances for the graphs that have a ground truth.
    """
    gt_builders = {
        "MUTAG": get_mutag_gt,
        "BAMultiShapes": get_gt_explanation_ba_multishapes,
        "BA-2motif": get_gt_explanation_ba_2motif,
    }
    build_gt = gt_builders.get(dataset_name, lambda data: data)

    distances = []
    for data, pred in zip(dataset, truncated_explanations):
        gt = build_gt(data)
        if getattr(gt, "explanation_nodes_subset", None) is not None:
            distances.append(distance_fn(gt, pred))
        # Reset the attributes so the annotation does not persist on the graph.
        gt.explanation_graph = None
        gt.explanation_nodes_subset = None
    return distances


def set_seed(seed):
    """Seed NumPy's global RNG, Python's ``random`` module and PyTorch.

    Args:
        seed: Integer seed applied to all three generators.
    """
    for seeder in (np.random.seed, random.seed, torch.manual_seed):
        seeder(seed)
