import torch
import torch_geometric
from torch_geometric.data import Data
import numpy as np
import networkx as nx
import torch_geometric.transforms as T
import matplotlib.pyplot as plt


from collections import Counter, defaultdict
from hashlib import blake2b


def generate_graphs_from_edge_index(
        edge_index, num_nodes, to_undirected=True):
    """Build a PyG ``Data`` graph and its NetworkX conversion from an edge list.

    Args:
        edge_index: Edge list accepted by ``torch.tensor``; it is cast to a
            ``torch.LongTensor``. Presumably laid out as ``[2, num_edges]``
            (the usual PyG convention) -- TODO confirm against callers.
        num_nodes: Node count stored on the ``Data`` object.
        to_undirected: When truthy, the edge tensor is passed through
            ``torch_geometric.utils.to_undirected`` before the graph is
            built; otherwise the edges are used exactly as given.

    Returns:
        A ``(graph, nx_graph)`` tuple: the ``Data`` object and the result of
        ``torch_geometric.utils.to_networkx`` called on it with default
        options.
    """
    edges = torch.tensor(edge_index).type(torch.LongTensor)
    if to_undirected:
        edges = torch_geometric.utils.to_undirected(edges)

    graph = Data(edge_index=edges, num_nodes=num_nodes)
    return graph, torch_geometric.utils.to_networkx(graph)



# __all__ = ["weisfeiler_lehman_graph_hash", "weisfeiler_lehman_subgraph_hashes"]

def _hash_label(label, digest_size):
    """Return the hexadecimal BLAKE2b digest of ``label``.

    Args:
        label: Text to hash.
        digest_size: Digest length in bytes; the returned hex string has
            twice that many characters.

    Returns:
        The digest as a hexadecimal string.
    """
    # UTF-8 yields the same bytes as ASCII for ASCII-only labels, so existing
    # hashes are unchanged, while labels containing non-ASCII characters
    # (e.g. stringified node attributes) no longer raise UnicodeEncodeError.
    return blake2b(label.encode("utf-8"), digest_size=digest_size).hexdigest()


def _init_node_labels(G, edge_attr, node_attr):
    """Build the initial string label of every node in ``G``.

    Args:
        G: NetworkX-style graph.
        edge_attr: Edge attribute key, or a falsy value when edge
            attributes are not used.
        node_attr: Node attribute key, or a falsy value when node
            attributes are not used.

    Returns:
        A dict mapping each node to its starting label:

        * ``str`` of the node's ``node_attr`` value when ``node_attr`` is
          truthy;
        * the empty string when only ``edge_attr`` is truthy;
        * otherwise ``str`` of the node's out-degree, or of its degree
          when the graph object has no ``out_degree`` method.
    """
    if node_attr:
        return {u: str(dd[node_attr]) for u, dd in G.nodes(data=True)}
    if edge_attr:
        return {u: "" for u in G}

    # Graphs exposing out_degree keep using it, so results for them are
    # unchanged. Graphs without that method (e.g. undirected ones) fall back
    # to plain degree instead of raising AttributeError.
    degree_view = G.out_degree() if hasattr(G, "out_degree") else G.degree()
    return {u: str(deg) for u, deg in degree_view}


def _neighborhood_aggregate(G, node, node_labels, edge_attr=None):
    """Compute the aggregated (unhashed) label of ``node``.

    The result is the node's current label followed by the labels of its
    neighbors, as returned by ``G.neighbors(node)``, joined in sorted order
    with no separator.

    Args:
        G: NetworkX-style graph.
        node: Node whose new label is computed.
        node_labels: Dict mapping every node to its current string label.
        edge_attr: Optional edge attribute key. When not ``None``, each
            neighbor label is prefixed with ``str`` of that attribute on
            the edge from ``node`` to the neighbor.

    Returns:
        The concatenated label string.
    """
    neighbor_labels = sorted(
        ("" if edge_attr is None else str(G[node][nbr][edge_attr]))
        + node_labels[nbr]
        for nbr in G.neighbors(node)
    )
    return node_labels[node] + "".join(neighbor_labels)

def weisfeiler_lehman_graph_hash(
    G, edge_attr=None, node_attr=None, iterations=3, digest_size=4
):
    """Return a Weisfeiler-Lehman style hash string for graph ``G``.

    Every round replaces each node's label with the hash of its own label
    combined with its neighbors' labels, then records how often each
    resulting label occurs. The counts from all rounds are hashed together
    to give the final value.

    Args:
        G: NetworkX-style graph.
        edge_attr: Optional edge attribute key mixed into neighbor labels.
        node_attr: Optional node attribute key used for the initial labels.
        iterations: Number of relabelling rounds.
        digest_size: Digest size in bytes, used both for the per-node label
            hashes and for the final graph hash.

    Returns:
        The hexadecimal digest string of the accumulated label counts.
    """
    labels = _init_node_labels(G, edge_attr, node_attr)

    histogram = []
    for _ in range(iterations):
        # One relabelling round: all new labels are computed from the
        # previous round's labels before `labels` is rebound.
        labels = {
            node: _hash_label(
                _neighborhood_aggregate(G, node, labels, edge_attr=edge_attr),
                digest_size,
            )
            for node in G.nodes()
        }
        label_counts = Counter(labels.values())
        # Counter keys are unique, so sorting the (label, count) pairs
        # orders them by label.
        histogram += sorted(label_counts.items())

    return _hash_label(str(tuple(histogram)), digest_size)
