import networkx as nx
import torch
from random import randint
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, Arc
import numpy as np
import os

from torch_geometric.data import Data
from torch_geometric.utils import remove_isolated_nodes, dropout_edge, to_networkx, subgraph
from torch_scatter import scatter_sum, scatter_add

from GOOD.utils.splitting import split_graph
from GOOD.definitions import ROOT_DIR

# Colour assigned to each edge category when drawing graphs.
# NOTE(review): "inv"/"spu" presumably stand for invariant/spurious edges and
# "added" for edges introduced by a perturbation — confirm against the callers.
# The two commented entries are the swapped inv/spu colour scheme.
edge_colors = {
    "inv": "green",
    "spu": "black",
    # "inv": "black",
    # "spu": "green",
    "added": "red"
}
# Colour assigned to a node depending on a boolean flag
# (True -> "red", False -> matplotlib's default blue "#1f78b4").
# NOTE(review): which flag this is keyed on is not visible in this file.
node_colors = {
    True: "red",
    False: "#1f78b4"
}




def get_color_based_on_dataset(config, x):
    """
        Choose the display colour of one node from its feature vector.

        Args:
            config: run configuration; only `config.dataset.dataset_name` is read.
            x: feature vector of the node.

        Returns:
            - dataset name containing "BAColorGV" or "BAColorRB": a colour name
              selected by the position of the largest entry of `x`
              (0 red, 1 blue, 2 green, 3 violet, anything else orange);
            - "MNIST", "CPatchMNIST", "CPatchMNIST2": the first three entries of `x`;
            - "MUTAG": an entry of matplotlib's `tab10` colormap selected by the
              position of the largest entry of `x`;
            - "GraphSST2Planted": the constant "lightblue";
            - any other dataset name: None.
    """
    dataset_name = config.dataset.dataset_name

    if "BAColorGV" in dataset_name or "BAColorRB" in dataset_name:
        palette = ("red", "blue", "green", "violet")
        hottest = int(np.argmax(x))
        return palette[hottest] if hottest < len(palette) else "orange"

    if dataset_name in ("MNIST", "CPatchMNIST", "CPatchMNIST2"):
        return x[:3]

    if dataset_name == "MUTAG":
        # NOTE(review): the index is reduced modulo 14 although `tab10` should
        # only provide ten distinct colours — confirm this is intended.
        return plt.cm.tab10(np.argmax(x) % 14)

    if dataset_name == "GraphSST2Planted":
        return "lightblue"

    # Unknown dataset: no colour is defined.
    return None

def draw_colored(
    config,
    G,
    name,
    thrs=None,
    node_expl=None,
    edge_expl="",
    subfolder="",
    pos=None,
    save=True,
    figsize=(6.4, 4.8),
    nodesize=150,
    with_labels=True,
    title=None,
    ax=None,
    topk=None,
):
    """
        Draw the networkx graph `G` with dataset-specific node colours and labels,
        then either save the figure as a PDF or show it.

        Args:
            config: run configuration. Reads `config.dataset.dataset_name` and, when
                saving, `config.load_split`, `config.util_model_dirname` and
                `config.random_seed` (used to build the output directory).
            G: networkx graph. Node attributes read here: "x" (features),
                "node_gt" (optional ground truth flags), "node_mask" (explanation
                membership) and, for GraphSST2Planted, "sentence_tokens".
            name: file name (without extension) of the saved PDF.
            thrs: optional threshold; nodes whose score in `node_expl` is >= `thrs`
                are labelled "E". When None, the "node_mask" attribute decides.
            node_expl: per-node scores, printed next to each node. Required for
                every dataset except MUTAG and GraphSST2Planted.
            edge_expl: unused.
            subfolder: sub-directory of the plots folder to save into.
            pos: optional node positions. Ignored (recomputed) for the MNIST-like
                and SST2 datasets.
            save: save to disk when True, otherwise call `plt.show()`.
            figsize: size of the newly created figure.
            nodesize: node size forwarded to `nx.draw`.
            with_labels: forwarded to `nx.draw`.
            title: figure super-title.
            ax: axes forwarded to `nx.draw` only.
                NOTE(review): labels and halos are still drawn on the current axes
                of the freshly created figure — confirm whether `ax` is meant to
                receive them as well.
            topk: unused.

        Returns:
            The node positions that were used, so that they can be reused for a
            later drawing of the same graph.

        Raises:
            AssertionError: if `node_expl` is None for a dataset other than MUTAG
                and GraphSST2Planted (drawing without node scores is not implemented).
    """
    dataset_name = config.dataset.dataset_name
    image_datasets = ["MNIST", "CPatchMNIST", "CPatchMNIST2"]
    is_image_dataset = dataset_name in image_datasets

    # Fail before any figure is created. An explicit raise (instead of `assert`)
    # keeps the check alive when Python runs with -O.
    if dataset_name not in ("MUTAG", "GraphSST2Planted") and node_expl is None:
        raise AssertionError("Not implemented")

    plt.figure(figsize=figsize)

    node_gt = list(nx.get_node_attributes(G, "node_gt").values())
    node_attr = list(nx.get_node_attributes(G, "x").values())
    node_mask = list(nx.get_node_attributes(G, "node_mask").values())

    if pos is None and dataset_name not in image_datasets + ["GraphSST2Planted"]:
        pos = nx.kamada_kawai_layout(G)
        # For BAColorGV it can help to pin the last two nodes, e.g.
        # pos[len(G.nodes)-2] = (0.8, 0.8) and pos[len(G.nodes)-1] = (-0.8, -0.8).
    elif is_image_dataset:
        # Positions come from entries 4 and 3 of the node features
        # (presumably pixel coordinates — confirm).
        pos = [(x[4], -x[3]) for x in node_attr]
    elif "SST2" in dataset_name:
        # Lay the tokens out on a horizontal zig-zag.
        pos = {i: (i * 10, (-1) ** i) for i in range(G.number_of_nodes())}

    # Ground truth, when available, overrides the dataset-specific colouring.
    if len(node_gt) > 0:
        colors = ["orange" if node_gt[i] else "blue" for i in range(len(node_attr))]
    else:
        colors = [get_color_based_on_dataset(config, feat) for feat in node_attr]

    nx.draw(
        G,
        with_labels=with_labels,
        pos=pos,
        ax=ax,
        node_size=nodesize,
        node_color=colors,
        edge_color=(0.55, 0.55, 0.55),
        alpha=0.9 if is_image_dataset else 0.5
    )

    if dataset_name == "MUTAG":
        index_to_symbol = {0: 'C', 1: 'O', 2: 'Cl', 3: "H", 4: "N", 5: "F", 6: "Br", 7: "S", 8: "P", 9: "I", 10: "Na", 11: "K", 12: "Li", 13: "Ca"}
        node_labels = {i: index_to_symbol[np.argmax(node_attr[i])] for i in range(len(node_attr))}
    elif dataset_name == "GraphSST2Planted":
        sentence_tokens = list(nx.get_node_attributes(G, "sentence_tokens").values())
        node_labels = {i: sentence_tokens[i] for i in range(len(node_attr))}
    elif thrs is not None:
        node_labels = {u: "E" if score >= thrs else "" for u, score in enumerate(node_expl)}
    else:
        node_labels = {u: "E" if is_relev else "" for u, is_relev in enumerate(node_mask)}

    # Annotate nodes with 'E' or other labels
    nx.draw_networkx_labels(
        G,
        pos,
        node_labels,
        font_size=12,
        font_color="red" if is_image_dataset else "black",
        alpha=0.6
    )

    # Annotate with node scores, slightly above each node
    if node_expl is not None and pos is not None:
        offset = 0.05
        coords = pos.items() if isinstance(pos, dict) else enumerate(pos)
        label_pos = {n: (x, y + offset) for n, (x, y) in coords}

        nx.draw_networkx_labels(
            G,
            label_pos,
            labels={n: f"{v:.2f}" for n, v in enumerate(node_expl)},
            font_size=6,
            alpha=0.8
        )

    # Transparent halo around the nodes flagged by "node_mask"
    highlight_nodes = [u for u, is_relev in enumerate(node_mask) if is_relev]
    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=highlight_nodes,
        node_color="yellow",
        node_size=1000,   # larger than the actual node
        alpha=0.3         # transparent halo
    )

    plt.suptitle(title)

    if save:
        path = f'{ROOT_DIR}/GOOD/kernel/pipelines/plots/{subfolder}/{config.load_split}_{config.util_model_dirname}_{config.random_seed}/'
        try:
            # exist_ok avoids the check-then-create race between concurrent runs
            os.makedirs(path, exist_ok=True)
        except OSError as e:
            print(e)
            exit(e)
        plt.savefig(f'{path}/{name}.pdf')
    else:
        plt.show()

    plt.close()
    return pos

def plot_sentence_graph(G, name, subfolder, config, title):
    """
    Plot a sentence graph (e.g., from GraphSST2) as a row of tokens connected
    by arcs, and save it as a PDF.

    Args:
        G: graph object exposing the attributes:
           - sentence_tokens: iterable of strings, one per node
           - edge_index: [2, E] tensor
           - node_expl: per-node scores, printed below each token
           - node_mask: per-node flags; flagged tokens get a yellow background
        name: file name (without extension) of the saved PDF.
        subfolder: sub-directory of the plots folder to save into.
        config: run configuration; `load_split`, `util_model_dirname` and
            `random_seed` are used to build the output directory.
        title: figure super-title.
    """
    tokens = list(G.sentence_tokens)
    num_nodes = len(tokens)
    edges = G.edge_index.t().tolist()

    node_importance = G.node_expl

    fig, ax = plt.subplots(figsize=(len(tokens) * 0.7, 2))
    ax.set_axis_off()

    # Tokens are placed on integer horizontal positions
    x_coords = np.arange(num_nodes)

    # Draw words + highlights
    for i, token in enumerate(tokens):
        if G.node_mask[i]:
            ax.add_patch(plt.Rectangle((x_coords[i] - 0.4, -0.2), 0.8, 0.4,
                                       color='yellow', alpha=0.3, zorder=0))
        ax.text(x_coords[i], 0, token, ha='center', va='center', fontsize=10)
        ax.text(x_coords[i], -1, round(node_importance[i].item(), 2), ha='center', va='center', fontsize=6)

    # Draw edges as arcs
    for src, tgt in edges:
        if src == tgt:  # skip self-loops
            continue
        # Ensure left-to-right arc
        if src > tgt:
            src, tgt = tgt, src
        x1, x2 = x_coords[src], x_coords[tgt]
        xm = (x1 + x2) / 2
        width = abs(x2 - x1)
        height = width / 2 + 1  # longer edges get taller arcs
        arc = Arc((xm, 1), width=width, height=height,
                  theta1=0, theta2=180, color='black',
                  alpha=0.5, lw=0.5)
        ax.add_patch(arc)

    ax.set_xlim(-1, num_nodes)
    ax.set_ylim(-1, max(2, num_nodes/2))
    plt.suptitle(title)
    plt.tight_layout()
    path = f'{ROOT_DIR}/GOOD/kernel/pipelines/plots/{subfolder}/{config.load_split}_{config.util_model_dirname}_{config.random_seed}/'

    try:
        # exist_ok avoids the check-then-create race between concurrent runs
        os.makedirs(path, exist_ok=True)
    except OSError as e:
        print(e)
        exit(e)
    plt.savefig(f'{path}/{name}.pdf')
    plt.close()


def fidelity(graph, type):
    """
        Generate the perturbed sample according to Fidelity+ and Fidelity-.
        I.e., either remove the entire explanation, or the entire complement.
        Operationally, we keep the node induced subgraph of either relevant or irrelevant nodes.

        Args:
            graph: graph with `x`, `edge_index`, `y`, `node_expl`, `node_mask`,
                `edge_mask` and optionally `edge_attr` / `node_is_spurious`.
            type: "fidm" keeps the nodes selected by `graph.node_mask`,
                "fidp" keeps the remaining nodes.

        Returns:
            None when the explanation is empty, otherwise a one-element list with
            the `Data` object of the retained node-induced subgraph.

        Raises:
            ValueError: if `type` is neither "fidm" nor "fidp".
    """
    if graph.node_mask.sum() == 0: # discard empty explanations
        return None

    if type == "fidm":
        # preserve the node induced subgraph of relevant nodes
        nodes_to_keep = graph.node_mask
    elif type == "fidp":
        # preserve the node induced subgraph of IRrelevant nodes
        nodes_to_keep = torch.logical_not(graph.node_mask)
    else:
        raise ValueError(f"Unknown fidelity type '{type}': expected 'fidm' or 'fidp'")

    has_edge_attr = "edge_attr" in graph.keys()
    has_node_is_spurious = "node_is_spurious" in graph.keys()

    edge_index, edge_attr, edge_mask = subgraph(
        nodes_to_keep,
        graph.edge_index,
        edge_attr=graph.edge_attr if has_edge_attr else None,
        return_edge_mask=True,
        relabel_nodes=True,
        num_nodes=graph.x.shape[0]
    )
    return [
        Data(
            x=graph.x[nodes_to_keep],
            edge_index=edge_index,
            edge_attr=edge_attr,
            node_is_spurious=graph.node_is_spurious[nodes_to_keep] if has_node_is_spurious else None,
            y=graph.y,
            node_expl=graph.node_expl[nodes_to_keep],
            node_mask=graph.node_mask[nodes_to_keep],
            edge_mask=graph.edge_mask[edge_mask],  # mask values of the surviving edges
        )
    ]

def robust_fidelity(graph, type, p, expval_budget, inplace=False):
    """
        Generate the perturbed sample according to Robust Fidelity+ and Robust Fidelity-.
        I.e., remove random edges in a IID fashion.
        Partially inspired from https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/utils/dropout.html#dropout_edge

        For Necessity metrics: note that this function considers every edge outside of the complement node-induced subgraph as candidate edge
        to be removed. Thus, not only the ones strictly inside the explanation node-induced subgraph.
        For Sufficiency metrics: Specular

        Args:
            graph: graph with `x`, `edge_index`, `y`, `node_expl`, `node_mask`,
                `edge_mask` and optionally `edge_attr` / `node_is_spurious`.
            type: "rfidm" protects the edges among the nodes in `graph.node_mask`,
                "rfidp" protects the edges among the remaining nodes.
            p: every unprotected edge is dropped when its uniform draw is below `p`.
            expval_budget: number of perturbed samples to generate.
            inplace: when True, `graph` itself is modified and returned; this is
                only meaningful for a single sample.

        Returns:
            None when the explanation is empty, otherwise a list of
            `expval_budget` perturbed graphs (`[graph]` when `inplace` is True).

        Raises:
            ValueError: if `type` is unknown, or if `inplace` is combined with
                `expval_budget > 1`.
    """
    if graph.node_mask.sum() == 0: # discard empty explanations
        return None

    if type == "rfidm":
        # sample IID for each edge, then force edges inside of R to remain
        nodes_to_keep = graph.node_mask
    elif type == "rfidp":
        # sample IID from the explanation, so get the subgraph induced by the complement
        nodes_to_keep = torch.logical_not(graph.node_mask)
    else:
        raise ValueError(f"Unknown robust fidelity type '{type}': expected 'rfidm' or 'rfidp'")

    if inplace and expval_budget > 1:
        # a single object cannot hold several independent perturbations
        raise ValueError("inplace=True supports a single sample only (expval_budget <= 1)")

    has_edge_attr = "edge_attr" in graph.keys()
    has_node_is_spurious = "node_is_spurious" in graph.keys()

    row, col = graph.edge_index
    _, _, force_to_keep = subgraph(
        nodes_to_keep,
        graph.edge_index,
        return_edge_mask=True,
        relabel_nodes=True,
        num_nodes=graph.x.shape[0]
    )

    if inplace:
        ret = [graph]
    else:
        # edge_index / edge_mask / edge_attr are filled in per sample below
        ret = [
            Data(
                x=graph.x,
                edge_attr=None,
                node_is_spurious=graph.node_is_spurious if has_node_is_spurious else None,
                y=graph.y,
                node_expl=graph.node_expl,
                node_mask=graph.node_mask,
            )
            for _ in range(expval_budget)
        ]

    edge_masks = torch.rand((expval_budget, row.size(0)), device=graph.edge_index.device) >= p
    edge_masks[:, force_to_keep] = True # force to keep edges inside R or C, based on the metric
    edge_masks[:, row > col] = False  # force undirected: sample one direction, mirror it below

    # Read the originals once, so that in-place updates cannot feed back into the sampling
    source_edge_index = graph.edge_index
    source_edge_mask = graph.edge_mask
    source_edge_attr = graph.edge_attr if has_edge_attr else None

    for j in range(expval_budget):
        kept = edge_masks[j]
        edge_index = source_edge_index[:, kept]
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

        # indices of the kept edges, repeated for the mirrored direction
        idx_kept_edges = kept.nonzero().view(-1).repeat(2)

        ret[j].edge_index = edge_index
        ret[j].edge_mask = source_edge_mask[idx_kept_edges]
        if has_edge_attr:
            ret[j].edge_attr = source_edge_attr[idx_kept_edges]
    return ret

def nec_budget(graph, avg_graph_size, p, expval_budget):
    """
        Modification of RFID+ to account for irrelevant edges in the explanation.
        From 'https://openreview.net/pdf?id=kiOxNsrpQy'
        Instead of sampling edges IID, sample a fixed budget proportional to the average size of graphs.

        Note that this function considers every edge outside of the complement node-induced subgraph as candidate edge
        to be removed. Thus, not only the ones strictly inside the explanation node-induced subgraph.

        Args:
            graph: graph with `x`, `edge_index`, `y`, `node_expl`, `node_mask`,
                `edge_mask` and optionally `edge_attr` / `node_is_spurious`.
            avg_graph_size: average graph size the removal budget is proportional to.
            p: fraction of `avg_graph_size` used as removal budget.
            expval_budget: number of perturbed samples to generate.

        Returns:
            None when the explanation is empty, otherwise a list of
            `expval_budget` perturbed `Data` objects.
    """
    if graph.node_mask.sum() == 0: # discard empty explanations
        return None

    has_edge_attr = "edge_attr" in graph.keys()
    has_node_is_spurious = "node_is_spurious" in graph.keys()

    row, col = graph.edge_index
    _, _, force_to_keep_complement = subgraph(
        torch.logical_not(graph.node_mask),
        graph.edge_index,
        return_edge_mask=True,
        relabel_nodes=True,
        num_nodes=graph.x.shape[0]
    )

    # edge_index / edge_mask / edge_attr are filled in per sample below
    ret = [
        Data(
            x=graph.x,
            edge_attr=None,
            node_is_spurious=graph.node_is_spurious if has_node_is_spurious else None,
            y=graph.y,
            node_expl=graph.node_expl,
            node_mask=graph.node_mask,
        )
        for _ in range(expval_budget)
    ]

    # Only one direction (row <= col) of the edges outside C can be selected.
    # Capping the budget at that count guarantees top-k never has to fall back
    # on a -inf entry, i.e. on an edge that must be kept.
    removable = torch.logical_not(force_to_keep_complement) & (row <= col)
    B = min(int(p * avg_graph_size), int(removable.sum().item()))

    # set to False (hence remove) the B edges with highest random weight
    edge_weights = torch.rand((expval_budget, row.size(0)), device=graph.edge_index.device)
    edge_weights[:, torch.logical_not(removable)] = -torch.inf # edges in C and mirrored directions cannot be chosen
    edges_to_remove = torch.topk(edge_weights, k=B, dim=1).indices # B edges with highest value are chosen to be removed

    edges_to_keep = torch.ones_like(edge_weights, device=graph.edge_index.device)
    edges_to_keep.scatter_(index=edges_to_remove, dim=1, value=False)
    edges_to_keep[:, row > col] = False  # force undirected: keep one direction, mirror it below
    edges_to_keep = edges_to_keep.bool()

    for j in range(expval_budget):
        kept = edges_to_keep[j]

        edge_index = graph.edge_index[:, kept]
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)

        # indices of the kept edges, repeated for the mirrored direction
        idx_kept_edges = kept.nonzero().view(-1).repeat(2)

        ret[j].edge_index = edge_index
        ret[j].edge_mask = graph.edge_mask[idx_kept_edges]
        if has_edge_attr:
            ret[j].edge_attr = graph.edge_attr[idx_kept_edges]
    return ret


def suff_intervent(graph, graph_database, graph_database_labels, expval_budget):
    """
        Interventional SUFF from 'https://openreview.net/pdf?id=kiOxNsrpQy'.
        Randomly attach R from G with C' of a random G'.
        The number of new random edges is the same as the number of edges that
        were removed from G to R (preserve number of edges, but randomly connect R with C').

        Args:
            graph: graph G with `x`, `edge_index`, `y`, `node_expl`, `node_mask`,
                `edge_mask` and optionally `edge_attr` / `node_is_spurious`.
            graph_database: indexable collection of graphs G' is sampled from.
            graph_database_labels: tensor with the label of every graph of the
                database; G' is sampled among the graphs labelled like G.
            expval_budget: number of merged samples to generate.

        Returns:
            None when the explanation is empty, otherwise a list of
            `expval_budget` `Data` objects, each being R merged with a C'.
    """
    if graph.node_mask.sum() == 0: # discard empty explanations
        return None

    def merge_graphs_randomly(data1: "Data", data2: "Data", num_random_edges, has_node_is_spurious) -> "Data":
        """Stack `data1` and `data2` and connect them with up to `num_random_edges` random undirected edges."""
        num_nodes_1 = data1.num_nodes
        num_nodes_2 = data2.num_nodes

        # Offset the edge index of the second graph
        edge_index2 = data2.edge_index + num_nodes_1

        # Concatenate edge indices
        merged_edge_index = torch.cat([data1.edge_index, edge_index2], dim=1)

        # Concatenate X
        merged_x = torch.cat([data1.x, data2.x], dim=0)

        # Concatenate edge features (if available)
        if hasattr(data1, 'edge_attr') and data1.edge_attr is not None:
            merged_edge_attr = torch.cat([data1.edge_attr, data2.edge_attr], dim=0)
        else:
            merged_edge_attr = None

        # Add random edges between the two graphs (avoiding duplicates)
        if num_random_edges > 0:
            all_possible_edges = torch.cartesian_prod(
                torch.arange(num_nodes_1),
                torch.arange(start=num_nodes_1, end=num_nodes_1+num_nodes_2)
            ).to(data1.y.device)

            if has_node_is_spurious:
                # remove edges connecting G/V: keep pairs where neither endpoint is spurious
                all_possible_edges_is_spurious = torch.cartesian_prod(
                    data1.node_is_spurious,
                    data2.node_is_spurious
                ).to(data1.y.device)
                all_possible_edges = all_possible_edges[all_possible_edges_is_spurious.sum(1) == 0,:]
            random_edges_idxs = torch.randperm(all_possible_edges.shape[0])[:num_random_edges]
            random_edges = all_possible_edges[random_edges_idxs, :].T # size: 2xnum_random_edges

            # Make bidirectional
            bidir_edges = torch.cat([random_edges, random_edges[[1, 0]]], dim=1)
            merged_edge_index = torch.cat([merged_edge_index, bidir_edges], dim=1)

            if merged_edge_attr is not None:
                # the new edges get all-zero features
                rand_edge_attr = torch.zeros((bidir_edges.size(1), merged_edge_attr.size(1)), device=data1.y.device)
                merged_edge_attr = torch.cat([merged_edge_attr, rand_edge_attr], dim=0)

        # NOTE(review): `edge_mask` is not extended for the random edges, so it
        # can be shorter than `edge_index` — check whether this is needed.
        return Data(
            x=merged_x,
            edge_index=merged_edge_index,
            edge_attr=merged_edge_attr,
            node_is_spurious=torch.cat([data1.node_is_spurious, data2.node_is_spurious], dim=0) if has_node_is_spurious else None,
            y=data1.y, # Watch out! this holds only in the invariance setup
            node_expl=torch.cat([data1.node_expl, data2.node_expl], dim=0),
            node_mask=torch.cat([data1.node_mask, data2.node_mask], dim=0),
            edge_mask=torch.cat([data1.edge_mask, data2.edge_mask], dim=0),
        )

    def count_boundary_edges(edge_index: torch.Tensor, subset: torch.Tensor, num_nodes: int) -> int:
        """Count the undirected edges with exactly one endpoint inside `subset`."""
        # Build the membership mask on the same device as the edges, so that it
        # can be indexed by them wherever the graph lives.
        mask = torch.zeros(num_nodes, dtype=torch.bool, device=edge_index.device)
        mask[subset] = True

        src, dst = edge_index

        # An edge crosses the boundary when its endpoints disagree on membership
        boundary_edges = mask[src] != mask[dst]
        return boundary_edges.sum().item() // 2 # count only one direction

    has_node_is_spurious = "node_is_spurious" in graph.keys()

    # Construct the Data object for R of G
    edge_index, edge_attr, edge_mask = subgraph(
        graph.node_mask,
        graph.edge_index,
        edge_attr=graph.edge_attr if "edge_attr" in graph.keys() else None,
        return_edge_mask=True,
        relabel_nodes=True,
        num_nodes=graph.x.shape[0]
    )
    R = Data(
        x=graph.x[graph.node_mask],
        edge_index=edge_index,
        edge_attr=edge_attr,
        node_is_spurious=graph.node_is_spurious[graph.node_mask] if has_node_is_spurious else None,
        y=graph.y,
        node_expl=graph.node_expl[graph.node_mask],
        node_mask=graph.node_mask[graph.node_mask],
        edge_mask=graph.edge_mask[edge_mask],
    )

    ret = []
    same_class_idx = (graph_database_labels == graph.y.item()).nonzero(as_tuple=True)[0]
    rnd_idxs = torch.randint(0, len(same_class_idx), (expval_budget,))
    num_random_edges = count_boundary_edges(edge_index=graph.edge_index, subset=graph.node_mask, num_nodes=graph.x.size(0))
    for i in range(expval_budget):
        # Sample G' from same class as G.
        # (Sampling from ANY class would be suitable only where the subgraph invariance holds.)
        rand_idx = same_class_idx[rnd_idxs[i].item()]
        graph_to_merge = graph_database[rand_idx]

        # Construct the Data object for C' of G'
        complement_mask = torch.logical_not(graph_to_merge.node_mask)
        edge_index, edge_attr, edge_mask = subgraph(
            complement_mask,
            graph_to_merge.edge_index,
            edge_attr=graph_to_merge.edge_attr if "edge_attr" in graph_to_merge.keys() else None,
            return_edge_mask=True,
            relabel_nodes=True,
            num_nodes=graph_to_merge.x.shape[0]
        )
        C_dash = Data(
            x=graph_to_merge.x[complement_mask],
            edge_index=edge_index,
            edge_attr=edge_attr,
            node_is_spurious=graph_to_merge.node_is_spurious[complement_mask] if has_node_is_spurious else None,
            y=graph_to_merge.y,
            node_expl=graph_to_merge.node_expl[complement_mask],
            node_mask=graph_to_merge.node_mask[complement_mask],
            edge_mask=graph_to_merge.edge_mask[edge_mask],
        )

        # Merge R with C'
        ret.append(
            merge_graphs_randomly(R, C_dash, num_random_edges, has_node_is_spurious)
        )
    return ret


def counter_fid(graph, expval_budget):
    """
        Counterfactual Fidelity, following Alg. 1 of 'https://arxiv.org/pdf/2406.07955'.
        Every returned sample is a clone of `graph` whose node explanation scores are
        replaced by sigmoid-squashed draws from a Normal distribution having the mean
        and standard deviation of the original scores.
        Only the scores change, so the samples need to be forwarded to the CLF only
        (line 12 of Alg. 1).

        Args:
            graph: graph exposing `node_mask`, `node_expl`, `keys()` and `clone()`.
            expval_budget: number of perturbed clones to generate.

        Returns:
            None when the explanation is empty, otherwise a list of
            `expval_budget` clones with resampled `node_expl`.

        Raises:
            ValueError: unless `graph` carries node-level and no edge-level scores.
    """
    if graph.node_mask.sum() == 0: # discard empty explanations
        return None

    available = graph.keys()
    has_node_scores = "node_expl" in available
    has_edge_scores = "edge_expl" in available

    if has_edge_scores and not has_node_scores:
        raise ValueError("edge level explanation not supported")
    if has_edge_scores or not has_node_scores:
        raise ValueError("configuration not suported")

    original_scores = graph.node_expl
    score_mean = torch.mean(original_scores)
    score_std = torch.std(original_scores) + 1e-6
    if torch.isnan(score_std):
        # fall back to a tiny spread when the standard deviation is undefined
        score_std = 1e-6

    sampler = torch.distributions.Normal(loc=score_mean, scale=score_std)
    num_scores = original_scores.size(0)

    perturbed_graphs = []
    for _ in range(expval_budget):
        perturbed = graph.clone()
        fresh_scores = sampler.sample((num_scores,)).to(original_scores.device)
        perturbed.node_expl = fresh_scores.sigmoid()
        perturbed_graphs.append(perturbed)
    return perturbed_graphs


def suff_cause(graph, expval_budget):
    """
        Our proposed EST metric.
        Each sample is obtained in two steps: nodes outside the explanation are
        dropped at random (nodes of the explanation always survive), then edges of
        the resulting graph are dropped at random through `robust_fidelity` (RFID-).

        Args:
            graph: graph with `x`, `edge_index`, `y`, `node_expl`, `node_mask`,
                `edge_mask` and optionally `edge_attr` / `node_is_spurious`.
            expval_budget: number of perturbed samples to generate.

        Returns:
            None when the explanation is empty, otherwise a list of
            `expval_budget` perturbed graphs.
    """
    if graph.node_mask.sum() == 0: # discard empty explanations
        return None

    attr_names = graph.keys()
    carries_edge_attr = "edge_attr" in attr_names
    carries_spurious_flag = "node_is_spurious" in attr_names
    total_nodes = graph.x.shape[0]

    # One uniform score per (sample, node); a node survives with a score >= 0.5
    keep_scores = torch.rand((expval_budget, total_nodes), device=graph.x.device)
    keep_scores[:, graph.node_mask] = 1.0 # always keep nodes in R
    keep_masks = keep_scores >= 0.5

    samples = []
    for sample_idx in range(expval_budget):
        kept_nodes = keep_masks[sample_idx].nonzero().view(-1)

        # Step 1: node-induced subgraph of the surviving nodes
        sub_edge_index, sub_edge_attr, surviving_edges = subgraph(
            kept_nodes,
            graph.edge_index,
            edge_attr=graph.edge_attr if carries_edge_attr else None,
            return_edge_mask=True,
            relabel_nodes=True,
            num_nodes=total_nodes
        )
        node_sampled = Data(
            x=graph.x[kept_nodes],
            edge_index=sub_edge_index,
            edge_attr=sub_edge_attr,
            node_is_spurious=graph.node_is_spurious[kept_nodes] if carries_spurious_flag else None,
            y=graph.y,
            node_expl=graph.node_expl[kept_nodes],
            node_mask=graph.node_mask[kept_nodes],
            edge_mask=graph.edge_mask[surviving_edges],
        )

        # Step 2: random edge removal, applied directly on the fresh object
        edge_sampled = robust_fidelity(
            node_sampled,
            type="rfidm",
            p=0.5,
            expval_budget=1,
            inplace=True
        )
        samples.append(edge_sampled[0])
    return samples