import torch
import random
import math
import pdb
import torch.nn as nn
import torch.nn.functional as F
import scipy.sparse as ssp
from scipy import linalg
from scipy.linalg import inv, eig, eigh
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import to_scipy_sparse_matrix
from collections import defaultdict, Counter
from torch_scatter import scatter_min
from batch import Batch
from collections import defaultdict, Counter
import igraph
from tqdm import tqdm
import pynauty
import igraph as ig

def within_graph_color(data):
    """Assign every node an integer colour derived from its feature row.

    The distinct feature rows of ``data.x`` are sorted lexicographically and
    numbered ``0..k-1``; each node receives the number of its own row, so nodes
    with identical features share a colour and the colours are contiguous.

    Args:
        data: object with a tensor attribute ``x`` holding one feature row per
            node. A 1-D ``x`` is treated as one scalar feature per node.

    Returns:
        list[int]: the colour of each node, in node order.
    """
    x = data.x
    if x.dim() == 1:
        x = x.unsqueeze(1)
    # tolist() works for CPU and CUDA tensors and yields hashable Python scalars.
    rows = [tuple(row) for row in x.tolist()]
    # Deduplicate with a set (O(n)) instead of scanning a list per node (O(n^2)).
    color_of = {row: rank for rank, row in enumerate(sorted(set(rows)))}
    return [color_of[row] for row in rows]

def cl_to_orbit(cl, orbits):
    """Rank each vertex's canonical label within its automorphism orbit.

    Args:
        cl: canonical label of every vertex (indexable by vertex id).
        orbits: orbit id of every vertex. Ids do not need to be contiguous
            (e.g. nauty identifies an orbit by its least-numbered vertex), so
            the distinct ids actually present are iterated.

    Returns:
        list: for vertex ``v``, the position of ``cl[v]`` among the canonical
        labels of all vertices in ``v``'s orbit (0 = smallest).
    """
    cl_array = np.asarray(cl)
    orbit_array = np.asarray(orbits)
    cl2orb = np.zeros(len(cl_array), dtype=int)
    for o_id in np.unique(orbit_array):
        in_orb = orbit_array == o_id
        members = cl_array[in_orb]
        position = {val: rank for rank, val in enumerate(sorted(members))}
        cl2orb[in_orb] = [position[val] for val in members]
    return list(cl2orb)

def canlabel_orbitpart(data, direct=False):
    """Attach canonical labels and automorphism-orbit information to ``data``.

    Nodes are coloured with ``within_graph_color`` (identical feature rows share
    a colour). igraph's ``canonical_permutation`` provides the colour-respecting
    canonical label of every node. pynauty's ``autgrp`` provides the orbit of
    every node in the coloured graph.

    Args:
        data: graph with node features ``x`` and a ``(2, E)`` ``edge_index``.
        direct: treat edges as directed, for both igraph and pynauty.

    Returns:
        ``data`` with these attributes set:
        ``cl`` (canonical labels),
        ``orbits`` (orbit id per node, as returned by pynauty),
        ``cl2orb`` (rank of each node's canonical label inside its orbit),
        ``num_vertex`` (the node count).
    """
    num_node = data.x.shape[0]
    # Plain Python int pairs work for both igraph and pynauty.
    edges = [tuple(e) for e in data.edge_index.t().tolist()]
    in_color = within_graph_color(data)

    g = ig.Graph(directed=direct)
    g.add_vertices(num_node)
    g.add_edges(edges)
    cl = g.canonical_permutation(color=in_color)

    # within_graph_color numbers colours 0..k-1, so every colour class is non-empty.
    num_in_color = len(set(in_color))
    vertex_coloring = [set() for _ in range(num_in_color)]
    for v, c in enumerate(in_color):
        vertex_coloring[c].add(v)

    # Build the complete neighbour lists first and pass them to the pynauty
    # constructor in one go. The result then does not depend on how repeated
    # connect_vertex() calls for the same vertex combine.
    adjacency = defaultdict(list)
    for i, j in edges:
        adjacency[i].append(j)
    G = pynauty.Graph(number_of_vertices=num_node, directed=direct,
                      adjacency_dict=dict(adjacency),
                      vertex_coloring=vertex_coloring)
    orbits = pynauty.autgrp(G)[3]

    # Relative canonical label inside each orbit.
    cl2orb = cl_to_orbit(cl, orbits)

    data.cl2orb = torch.LongTensor(cl2orb)
    data.cl = torch.LongTensor(cl)
    data.orbits = torch.LongTensor(orbits)
    data.num_vertex = num_node
    return data



def canlabel(data, direct=False):
    """Attach igraph canonical node labels to ``data``.

    NOTE: this module defines ``canlabel`` twice. The later definition is the
    one bound to the name after import, so this version only runs if that one
    is removed.

    If ``data.x`` is None, the graph is labelled without colours and the node
    count is inferred from ``edge_index``. Otherwise nodes are coloured by
    ``within_graph_color`` before computing the canonical permutation.

    Args:
        data: graph with ``edge_index`` (``(2, E)``) and optional features ``x``.
        direct: build a directed igraph graph.

    Returns:
        ``data`` with ``cl`` (canonical labels) and ``num_vertex`` set.
    """
    x = data.x
    edge_index = data.edge_index
    edges = [tuple(e) for e in edge_index.t().tolist()]
    if x is None:
        # Infer the node count from the largest referenced index.
        num_node = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        num_node = x.shape[0]

    g = ig.Graph(directed=direct)
    g.add_vertices(num_node)
    g.add_edges(edges)
    if x is None:
        cl = g.canonical_permutation()
    else:
        # Only colour by features when there are features to colour by.
        cl = g.canonical_permutation(color=within_graph_color(data))

    data.cl = torch.LongTensor(cl)
    data.num_vertex = num_node
    return data

def canlabel(data, direct=False):
    """Attach igraph canonical node labels to ``data``.

    Nodes are coloured with ``within_graph_color`` (identical feature rows share
    a colour), and igraph's colour-respecting ``canonical_permutation`` is
    stored. If ``data.x`` is None, the graph is labelled without colours and
    the node count is inferred from ``edge_index``.

    Args:
        data: graph with ``edge_index`` (``(2, E)``) and optional features ``x``.
        direct: build a directed igraph graph.

    Returns:
        ``data`` with ``cl`` (canonical labels) and ``num_vertex`` set.
    """
    x = data.x
    edge_index = data.edge_index
    edges = [tuple(e) for e in edge_index.t().tolist()]
    if x is None:
        num_node = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    else:
        num_node = x.shape[0]

    g = ig.Graph(directed=direct)
    g.add_vertices(num_node)
    g.add_edges(edges)
    if x is None:
        cl = g.canonical_permutation()
    else:
        cl = g.canonical_permutation(color=within_graph_color(data))

    data.cl = torch.LongTensor(cl)
    data.num_vertex = num_node
    return data


def orb_cl_counter(dataset):
    """Size the orbit / in-orbit-rank vocabularies needed for ``dataset``.

    Runs ``canlabel_orbitpart`` on every graph (with a tqdm progress bar) and
    tracks the largest orbit id and the largest within-orbit rank seen.

    Returns:
        tuple: ``(max_orbit_id + 1, max_cl2orb + 1)``; each component is at least 2.
    """
    largest_orbit, largest_rank = 1, 1
    for idx in tqdm(range(len(dataset))):
        labelled = canlabel_orbitpart(dataset[idx])
        largest_rank = max(largest_rank, torch.max(labelled.cl2orb).item())
        largest_orbit = max(largest_orbit, torch.max(labelled.orbits).item())
    return largest_orbit + 1, largest_rank + 1

def cl_counter(dataset):
    """Return the largest canonical label produced by ``canlabel`` over ``dataset``.

    A tqdm progress bar is shown. The result is at least 1.

    NOTE(review): unlike ``orb_cl_counter``, this returns the maximum label
    itself rather than ``max + 1``. Confirm that callers add 1 when they use it
    as a vocabulary size.
    """
    largest = 1
    for idx in tqdm(range(len(dataset))):
        labelled = canlabel(dataset[idx])
        largest = max(largest, torch.max(labelled.cl).item())
    return largest

def k_hop_subgraph(node_idx, num_hops, edge_index, relabel_nodes=False,
                   num_nodes=None, flow='source_to_target', node_label='hop',
                   max_nodes_per_hop=None):
    """Extract the subgraph within ``num_hops`` hops of ``node_idx`` and label its nodes.

    Nodes are discovered hop by hop along ``edge_index``. If
    ``max_nodes_per_hop`` is set, at most that many newly found nodes are kept
    per hop, chosen with ``random.sample``.

    Args:
        node_idx: root node (an int in every call in this module).
        num_hops: maximum number of hops to expand.
        edge_index: ``(2, E)`` edge index of the full graph.
        relabel_nodes: if True, renumber the returned edges to positions in ``subset``.
        num_nodes: number of nodes in the full graph; inferred when None.
        flow: ``'source_to_target'`` or ``'target_to_source'``.
        node_label: how ``z`` is built.
            ``'hop'``: hop distance from the root, shape ``(n, 1)``.
            ``'spd'`` / ``'spdK'``: the first K recorded labels
            (distance + 1, zero padded), shape ``(n, K)``; K defaults to 2.
            ``'drnl'``: ``dist1 * (num_hops + 1) + dist2`` over the first two
            recorded labels, or ``dist1`` if only one exists; shape ``(n, 1)``.
            Any other value gives ``z = None``.
        max_nodes_per_hop: optional cap on the number of new nodes kept per hop.

    Returns:
        tuple ``(subset, edge_index, edge_mask, z)``:
        ``subset`` is the subgraph nodes, root first, then grouped by hop;
        ``edge_index`` is the edges with both endpoints in ``subset``;
        ``edge_mask`` is the boolean mask of those edges over the input edges;
        ``z`` is the labels described above.
    """
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    assert flow in ['source_to_target', 'target_to_source']
    if flow == 'target_to_source':
        row, col = edge_index
    else:
        col, row = edge_index
    device = row.device

    node_mask = row.new_empty(num_nodes, dtype=torch.bool)
    edge_mask = row.new_empty(row.size(0), dtype=torch.bool)

    subsets = [torch.tensor([node_idx], device=device).flatten()]
    visited = set(subsets[-1].tolist())
    # label[v] lists the "distance + 1" values recorded for v: the root gets 1,
    # and a node found while expanding hop h (0-based) gets h + 2.
    label = defaultdict(list)
    for node in subsets[-1].tolist():
        label[node].append(1)
    if node_label == 'hop':
        # torch.zeros/torch.full accept any device; the legacy
        # torch.LongTensor(..., device=...) constructor is CPU-only.
        hops = [torch.zeros(1, dtype=torch.long, device=device)]
    for h in range(num_hops):
        node_mask.fill_(False)
        node_mask[subsets[-1]] = True
        torch.index_select(node_mask, 0, row, out=edge_mask)
        new_nodes = col[edge_mask]
        tmp = []
        for node in new_nodes.tolist():
            if node in visited:
                continue
            tmp.append(node)
            label[node].append(h + 2)
        if len(tmp) == 0:
            break
        if max_nodes_per_hop is not None and max_nodes_per_hop < len(tmp):
            tmp = random.sample(tmp, max_nodes_per_hop)
        new_nodes = set(tmp)
        visited = visited.union(new_nodes)
        new_nodes = torch.tensor(list(new_nodes), device=device)
        subsets.append(new_nodes)
        if node_label == 'hop':
            hops.append(torch.full((len(new_nodes),), h + 1,
                                   dtype=torch.long, device=device))
    subset = torch.cat(subsets)
    if node_label == 'hop':
        hop = torch.cat(hops)
    # Make sure `node_idx` is the first entry of `subset`.
    subset = subset[subset != node_idx]
    subset = torch.cat([torch.tensor([node_idx], device=device), subset])

    z = None
    if node_label == 'hop':
        hop = hop[hop != 0]
        hop = torch.cat([torch.zeros(1, dtype=torch.long, device=device), hop])
        z = hop.unsqueeze(1)
    elif node_label.startswith('spd') or node_label == 'drnl':
        if node_label.startswith('spd'):
            # Keep the first k recorded distances ('spd' alone means k = 2).
            num_spd = int(node_label[3:]) if len(node_label) > 3 else 2
            z = torch.zeros(
                [subset.size(0), num_spd], dtype=torch.long, device=device
            )
        else:
            # 'drnl' (see "Link Prediction Based on Graph Neural Networks") is a
            # special case of spd2, folded into a single column.
            num_spd = 2
            z = torch.zeros([subset.size(0), 1], dtype=torch.long, device=device)

        for i, node in enumerate(subset.tolist()):
            dists = label[node][:num_spd]  # keep the first num_spd distances
            if node_label.startswith('spd'):
                # Use startswith (not ==) so that 'spd3', 'spd4', ... are filled too.
                z[i][:len(dists)] = torch.tensor(dists)
            else:
                dist1 = dists[0]
                dist2 = dists[1] if len(dists) == 2 else 0
                if dist2 == 0:
                    dist = dist1
                else:
                    dist = dist1 * (num_hops + 1) + dist2
                z[i][0] = dist

    node_mask.fill_(False)
    node_mask[subset] = True
    edge_mask = node_mask[row] & node_mask[col]

    edge_index = edge_index[:, edge_mask]

    if relabel_nodes:
        # Map global node ids to their position in `subset`.
        node_idx = row.new_full((num_nodes,), -1)
        node_idx[subset] = torch.arange(subset.size(0), device=device)
        edge_index = node_idx[edge_index]

    return subset, edge_index, edge_mask, z


def maybe_num_nodes(index, num_nodes=None):
    """Return ``num_nodes`` if given, else infer it as ``max(index) + 1``.

    An empty ``index`` yields 0 instead of failing on ``max()`` of an empty tensor.
    """
    if num_nodes is not None:
        return num_nodes
    if index.numel() == 0:
        return 0
    return int(index.max()) + 1


def khop_feature_trans(data, h=1, sample_ratio=1.0, max_nodes_per_hop=None,
                       node_label='hop', use_rd=False, use_ss=False, bound_list=[5, 10, 15]):
    """Attach rooted-subgraph sequence features for every node of ``data``.

    For each node ``ind``, the ``h``-hop subgraph rooted at ``ind`` is extracted
    with ``k_hop_subgraph`` and linearised with ``seq_label_trans``. That
    sequence has one slot for the root plus ``bound_list[k]`` slots for hop
    ``k + 1``. Per-subgraph results are concatenated and stored on ``data``.

    Args:
        data: ``torch_geometric.data.Data`` graph.
        h: number of hops passed to ``k_hop_subgraph``.
        sample_ratio, use_ss: accepted for interface compatibility; unused.
        max_nodes_per_hop: optional per-hop sampling cap for ``k_hop_subgraph``.
        node_label: labelling scheme passed to ``k_hop_subgraph``.
        use_rd: also store, in ``data.subg_rd``, the resistance distance from
            each root to every node of its subgraph.
        bound_list: sequence slots reserved per hop (only read, never modified).

    Returns:
        ``data``, with these attributes set:
        ``total_seq_nodes``, ``number_edges``, ``subg_nodes``,
        ``subg_nodes_seq``, ``subg_edges``, ``subg_masks``, ``subg_rd``
        (only when ``use_rd``), and the per-subgraph size lists
        ``subg_node_size``, ``subg_edge_size``, ``subg_mask_size``.
    """
    assert isinstance(data, Data)
    edge_index, num_nodes = data.edge_index, data.num_nodes

    subg_nodes = []      # global node ids of each subgraph (to gather node embeddings)
    subg_nodes_seq = []  # slot of each subgraph node in the padded sequence
    subg_edges = []      # subgraph edges expressed in sequence slots
    subg_masks = []      # mask over all graph edges (to gather edge embeddings)
    subg_rd = []
    subg_num_nodes = []
    subg_num_edges = []
    subg_mask_size = []

    data.total_seq_nodes = sum(bound_list) + 1
    data.number_edges = edge_index.size(1)

    for ind in range(num_nodes):
        nodes_, edge_index_, edge_mask_, z_ = k_hop_subgraph(
            ind, h, edge_index, True, num_nodes, node_label=node_label,
            max_nodes_per_hop=max_nodes_per_hop
        )
        seq_index, seq_edge_index = seq_label_trans(z_, edge_index_, bound_list)

        subg_num_nodes.append(nodes_.size(0))
        # Tensor.sum also covers an edgeless graph; builtin sum() of an empty
        # mask returns the int 0, which has no .item().
        subg_num_edges.append(int(edge_mask_.sum()))
        subg_mask_size.append(edge_mask_.size(0))

        subg_nodes.append(nodes_)
        subg_nodes_seq.append(seq_index)
        subg_edges.append(seq_edge_index)
        subg_masks.append(edge_mask_)

        if use_rd:
            # Resistance distance from the root (local index 0) to every node,
            # via the Laplacian pseudo-inverse (see "Link prediction in complex
            # networks: A survey").
            adj = to_scipy_sparse_matrix(
                edge_index_, num_nodes=nodes_.shape[0]
            ).tocsr()
            laplacian = ssp.csgraph.laplacian(adj).toarray()
            try:
                L_inv = linalg.pinv(laplacian)
            except linalg.LinAlgError:
                # SVD did not converge: retry on a slightly regularised Laplacian.
                L_inv = linalg.pinv(laplacian + 0.01 * np.eye(*laplacian.shape))
            rd_to_x = L_inv[0, 0] + np.diag(L_inv) - L_inv[0, :] - L_inv[:, 0]
            # reshape(-1) keeps a single-node subgraph 1-D so torch.cat works.
            subg_rd.append(torch.FloatTensor(rd_to_x).reshape(-1))

    data.subg_nodes = torch.cat(subg_nodes, dim=0)
    data.subg_nodes_seq = torch.cat(subg_nodes_seq, dim=0)
    # NOTE(review): reshape(-1, 2) of the (2, total) tensor does not transpose.
    # Its rows pair consecutive sources, then consecutive targets. Consumers
    # must undo it with reshape(2, -1). Confirm this is intended rather than .t().
    data.subg_edges = torch.cat(subg_edges, dim=1).reshape(-1, 2)
    data.subg_masks = torch.cat(subg_masks, dim=0)
    if use_rd:
        data.subg_rd = torch.cat(subg_rd, dim=0)
    data.subg_node_size = subg_num_nodes
    data.subg_edge_size = subg_num_edges
    data.subg_mask_size = subg_mask_size

    return data


def seq_label_trans(z, edge_index, bound_list):
    """Map every subgraph node to its slot in a padded, hop-ordered sequence.

    Slot 0 is the root. Hop ``k`` (1-based) owns ``bound_list[k - 1]``
    consecutive slots, and nodes inside a hop are ordered by igraph's canonical
    label computed with the hop number as vertex colour. Nodes further than
    ``len(bound_list)`` hops keep slot 0. A hop with more nodes than its bound
    spills into the next hop's range; this is not checked.

    Assumes the nodes are ordered by hop with the root at index 0, as produced
    by ``k_hop_subgraph``. Also assumes the canonical labels of hop-k nodes are
    consecutive and start at that hop's first local index. TODO confirm the
    latter for igraph's coloured canonical form.

    Args:
        z: hop label per node, shape ``(num_node, 1)``.
        edge_index: ``(2, E)`` subgraph edges in local indices.
        bound_list: number of sequence slots reserved for each hop.

    Returns:
        tuple: ``(seq_index, seq_edge_index)``, the slot of each node and the
        edges rewritten in slots.
    """
    hop_of = z.squeeze().tolist()
    if type(hop_of) is int:
        hop_of = [hop_of]
    n = len(hop_of)

    graph = igraph.Graph(directed=False)
    graph.add_vertices(n)
    graph.add_edges(list(zip(edge_index[0], edge_index[1])))
    canon = graph.canonical_permutation(color=hop_of)

    hop_arr = np.array(hop_of)
    slots = [0] * n
    start = 1      # local index of the first node of the current hop
    slot_base = 1  # first sequence slot reserved for the current hop
    for depth, width in enumerate(bound_list, start=1):
        count = sum(hop_arr == depth)
        for node in range(start, start + count):
            slots[node] = canon[node] - start + slot_base
        start += count
        slot_base += width

    seq_index = torch.LongTensor(slots)
    return seq_index, seq_index[edge_index]



def khop_feature_trans_v2(data, h=1, sample_ratio=1.0, max_nodes_per_hop=None,
                       node_label='hop', use_rd=False, same_length=False, bound_list=[5, 10, 15]):
    """Attach rooted-subgraph sequence features (v2) for every node of ``data``.

    Like ``khop_feature_trans``, but it linearises with ``seq_label_trans_v2``.
    That function also returns each hop's in-hop canonical order and per-node
    neighbour counts towards the previous, same and next hop.

    Args:
        data: ``torch_geometric.data.Data`` graph.
        h: number of hops passed to ``k_hop_subgraph``.
        sample_ratio, same_length: accepted for interface compatibility; unused.
        max_nodes_per_hop: optional per-hop sampling cap for ``k_hop_subgraph``.
        node_label: labelling scheme passed to ``k_hop_subgraph``.
        use_rd: also store, in ``data.subg_rd``, root-to-node resistance
            distances for every subgraph.
        bound_list: sequence slots reserved per hop (only read, never modified).

    Returns:
        ``data``, with the same attributes as ``khop_feature_trans`` plus
        ``zs`` (concatenated node labels), ``cls`` (concatenated in-hop
        canonical orders) and ``hop_degs`` (concatenated neighbour counts).
    """
    assert isinstance(data, Data)
    edge_index, num_nodes = data.edge_index, data.num_nodes

    subg_nodes = []      # global node ids of each subgraph (to gather node embeddings)
    subg_nodes_seq = []  # slot of each subgraph node in the padded sequence
    subg_edges = []      # subgraph edges expressed in sequence slots
    subg_masks = []      # mask over all graph edges (to gather edge embeddings)
    subg_rd = []
    subg_num_nodes = []
    subg_num_edges = []
    subg_mask_size = []
    subg_cls = []
    subg_zs = []
    subg_seq_degs = []

    data.total_seq_nodes = sum(bound_list) + 1
    data.number_edges = edge_index.size(1)

    for ind in range(num_nodes):
        nodes_, edge_index_, edge_mask_, z_ = k_hop_subgraph(
            ind, h, edge_index, True, num_nodes, node_label=node_label,
            max_nodes_per_hop=max_nodes_per_hop
        )
        seq_index, seq_edge_index, cl, deg_class = seq_label_trans_v2(z_, edge_index_, bound_list)

        subg_num_nodes.append(nodes_.size(0))
        # Tensor.sum also covers an edgeless graph; builtin sum() of an empty
        # mask returns the int 0, which has no .item().
        subg_num_edges.append(int(edge_mask_.sum()))
        subg_mask_size.append(edge_mask_.size(0))

        subg_nodes.append(nodes_)
        subg_nodes_seq.append(seq_index)
        subg_cls.append(cl)
        subg_zs.append(z_)
        subg_seq_degs.append(deg_class)
        subg_edges.append(seq_edge_index)
        subg_masks.append(edge_mask_)

        if use_rd:
            if nodes_.size(0) > 1:
                # Resistance distance from the root (local index 0) to every
                # node via the Laplacian pseudo-inverse (see "Link prediction in
                # complex networks: A survey").
                adj = to_scipy_sparse_matrix(
                    edge_index_, num_nodes=nodes_.shape[0]
                ).tocsr()
                laplacian = ssp.csgraph.laplacian(adj).toarray()
                try:
                    L_inv = linalg.pinv(laplacian)
                except linalg.LinAlgError:
                    # SVD did not converge: retry on a slightly regularised Laplacian.
                    L_inv = linalg.pinv(laplacian + 0.01 * np.eye(*laplacian.shape))
                rd_to_x = L_inv[0, 0] + np.diag(L_inv) - L_inv[0, :] - L_inv[:, 0]
                subg_rd.append(torch.FloatTensor(rd_to_x).reshape(-1))
            else:
                # A lone root is at distance 0 from itself.
                subg_rd.append(torch.FloatTensor([0]))

    data.subg_nodes = torch.cat(subg_nodes, dim=0)
    data.subg_nodes_seq = torch.cat(subg_nodes_seq, dim=0)
    # NOTE(review): reshape(-1, 2) of the (2, total) tensor does not transpose.
    # Consumers must undo it with reshape(2, -1). Confirm this is intended.
    data.subg_edges = torch.cat(subg_edges, dim=1).reshape(-1, 2)
    data.subg_masks = torch.cat(subg_masks, dim=0)
    if use_rd:
        data.subg_rd = torch.cat(subg_rd, dim=0)
    data.subg_node_size = subg_num_nodes
    data.subg_edge_size = subg_num_edges
    data.subg_mask_size = subg_mask_size
    data.zs = torch.cat(subg_zs, dim=0)
    data.cls = torch.cat(subg_cls, dim=0)
    data.hop_degs = torch.cat(subg_seq_degs, dim=0)

    return data

def seq_label_trans_v2(z, edge_index, bound_list):
    """Sequence slots plus in-hop canonical order and hop-wise neighbour counts.

    Slot assignment is the same as ``seq_label_trans``: slot 0 is the root, and
    hop ``k`` owns ``bound_list[k - 1]`` slots ordered by the igraph canonical
    label coloured by hop. Assumes nodes are ordered by hop with the root at
    index 0 (as ``k_hop_subgraph`` returns them). TODO confirm.

    Args:
        z: hop label per node, shape ``(num_node, 1)``.
        edge_index: ``(2, E)`` subgraph edges in local indices.
        bound_list: number of sequence slots reserved for each hop.

    Returns:
        tuple ``(seq_index, seq_edge_index, hop_cl, deg_class)``:
        ``seq_index`` is the slot of each node;
        ``seq_edge_index`` is the edges rewritten in slots;
        ``hop_cl`` is ``[0]`` followed, hop by hop, by the argsort of that
        hop's canonical labels;
        ``deg_class`` is a ``(num_node, 3)`` LongTensor whose row ``v`` holds
        v's adjacency-matrix row sums over hop ``k - 1``, ``k`` and ``k + 1``
        (``k`` = v's hop; the last column stays 0 on the last hop). The root
        row is ``[0, 0, #hop-1 nodes]``.
    """
    num_hops = len(bound_list)
    hop_of = z.squeeze().tolist()
    if type(hop_of) is int:
        hop_of = [hop_of]
    n = len(hop_of)

    graph = igraph.Graph(directed=False)
    graph.add_vertices(n)
    graph.add_edges(list(zip(edge_index[0], edge_index[1])))
    canon = graph.canonical_permutation(color=hop_of)
    adj = np.array(list(graph.get_adjacency()))
    hop_arr = np.array(hop_of)

    deg_class = np.zeros((n, 3))
    deg_class[0, 2] = sum(hop_arr == 1)
    slots = [0] * n
    in_hop_orders = [torch.LongTensor([0])]

    start, slot_base = 1, 1
    for depth in range(1, num_hops + 1):
        count = sum(hop_arr == depth)
        stop = start + count
        in_hop_orders.append(torch.argsort(torch.LongTensor(canon[start:stop])))
        prev_mask = hop_arr == depth - 1
        cur_mask = hop_arr == depth
        next_mask = hop_arr == depth + 1
        for node in range(start, stop):
            slots[node] = canon[node] - start + slot_base
            deg_class[node, 0] = np.sum(adj[node, prev_mask])
            deg_class[node, 1] = np.sum(adj[node, cur_mask])
            if depth < num_hops:
                deg_class[node, 2] = np.sum(adj[node, next_mask])
        start = stop
        slot_base += bound_list[depth - 1]

    seq_index = torch.LongTensor(slots)
    return (seq_index, seq_index[edge_index],
            torch.cat(in_hop_orders), torch.LongTensor(deg_class))

"""Multilabel canonical labeling GNN"""

def tuplabel_to_seqlabel(label_dict):
    """Lay label tuples out on one flat sequence, one contiguous range per label.

    Labels are placed in ascending order of their count. Ties keep the order of
    ``label_dict``. Each label is given the start index of a range as long as
    its count. A tqdm progress bar is shown.

    Args:
        label_dict: mapping label tuple -> number of slots it needs.

    Returns:
        tuple ``(seqlabel_dict, total_length)``: label -> start index, and the
        sum of all counts (computed with ``np.sum``).
    """
    ordered = sorted(label_dict.items(), key=lambda item: item[1])
    total_length = np.sum(np.array(list(label_dict.values())))
    seqlabel_dict = {}
    start = 0
    for k in tqdm(range(len(ordered))):
        label, width = ordered[k]
        seqlabel_dict[label] = start
        start += width
    return seqlabel_dict, total_length


def hopdeg_labeling(dataset, h=3):
    """Collect the structural label tuples of every rooted subgraph in ``dataset``.

    For every graph and every node ``ind``, the ``h``-hop subgraph rooted at
    ``ind`` is extracted. Each non-root node ``v`` gets the tuple
    ``(hop, deg, dc1, dc2, dc3, induc_n, peri2, peri3)``:

    * ``hop``: hop distance of ``v`` from the root;
    * ``deg``: degree of ``v`` in the subgraph (igraph ``degree()``);
    * ``dc1``/``dc2``/``dc3``: v's adjacency-matrix row sums over the previous /
      same / next hop (``dc3`` is 0 on the last hop);
    * ``induc_n``: number of nodes at local index >= the first index of v's hop
      that are reachable from ``v`` by a walk of length 2;
    * ``peri2``/``peri3``: sum of adjacency entries among v's neighbours in the
      same / next hop (``peri3`` is 0 on the last hop).

    Assumes subgraph nodes are ordered by hop with the root first (as produced
    by ``k_hop_subgraph``). A tqdm progress bar is shown.

    Returns:
        dict: each tuple mapped to the largest number of times it occurs in a
        single rooted subgraph, with keys in first-seen order.
    """
    label_dict = {}
    for i in tqdm(range(len(dataset))):  # graphs
        data = dataset[i]
        edge_index, num_nodes = data.edge_index, data.num_nodes
        for ind in range(num_nodes):  # rooted subgraphs
            _, edge_index_, _, z = k_hop_subgraph(
                ind, h, edge_index, True, num_nodes, node_label='hop')
            z = z.squeeze().tolist()
            if type(z) is int:
                z = [z]
            num_node_ = len(z)
            g = igraph.Graph(directed=False)
            g.add_vertices(num_node_)
            g.add_edges([(a, b) for a, b in zip(edge_index_[0], edge_index_[1])])
            adj = np.array(list(g.get_adjacency()))
            two_adj = np.matmul(adj, adj)
            deg = g.degree()
            z_arr = np.array(z)

            # Multiplicity of each tuple inside this subgraph. A dict gives
            # O(1) membership and keeps insertion order.
            ingraph_label_dict = {}
            accum = 1  # local index of the first node of the current hop (root is 0)
            for hop in range(1, h + 1):
                hop_nodes = sum(z_arr == hop)
                seq0_hop_idx = z_arr == hop - 1
                seq1_hop_idx = z_arr == hop
                seq2_hop_idx = z_arr == hop + 1
                outer = np.zeros(num_node_, dtype=bool)
                outer[accum:] = True  # current hop and beyond
                for idx in range(accum, accum + hop_nodes):
                    dc1 = np.sum(adj[idx, seq0_hop_idx])
                    dc2 = np.sum(adj[idx, seq1_hop_idx])
                    induc_n = np.sum(two_adj[idx, outer] > 0.5)
                    nbr = np.zeros(num_node_, dtype=bool)
                    nbr[seq1_hop_idx] = adj[idx, seq1_hop_idx] > 0.5
                    peri2 = np.sum(adj[:, nbr][nbr, :])
                    if hop < h:
                        dc3 = np.sum(adj[idx, seq2_hop_idx])
                        nbr = np.zeros(num_node_, dtype=bool)
                        nbr[seq2_hop_idx] = adj[idx, seq2_hop_idx] > 0.5
                        peri3 = np.sum(adj[:, nbr][nbr, :])
                    else:
                        dc3 = 0
                        peri3 = 0
                    label_tuple = (z[idx], deg[idx], dc1, dc2, dc3,
                                   induc_n, peri2, peri3)
                    ingraph_label_dict[label_tuple] = ingraph_label_dict.get(label_tuple, 0) + 1
                accum += hop_nodes

            # Keep the largest per-subgraph multiplicity seen for each tuple.
            for key, count in ingraph_label_dict.items():
                label_dict[key] = max(count, label_dict.get(key, count))
    return label_dict


def khop_feature_trans_mulabel(data, h=3, sample_ratio=1.0, max_nodes_per_hop=None,
                          node_label='hop', use_rd=False, use_ss=False, seqlabel_dict=None, seq_len=100):
    """Attach rooted-subgraph sequence features using multi-label slot assignment.

    For each node, the ``h``-hop subgraph rooted at it is extracted with
    ``k_hop_subgraph`` and linearised with ``seq_label_trans_mulabel``, which
    looks each node's structural tuple up in ``seqlabel_dict``.

    Args:
        data: ``torch_geometric.data.Data`` graph.
        h: number of hops.
        sample_ratio, use_ss: accepted for interface compatibility; unused.
        max_nodes_per_hop: optional per-hop sampling cap for ``k_hop_subgraph``.
        node_label: labelling scheme passed to ``k_hop_subgraph``.
        use_rd: also store, in ``data.subg_rd``, root-to-node resistance distances.
        seqlabel_dict: mapping structural tuple -> sequence start slot. It is
            required whenever a subgraph has more than one node.
        seq_len: sequence length; ``total_seq_nodes`` is ``seq_len + 1``.

    Returns:
        ``data``, with ``total_seq_nodes``, ``number_edges``, ``subg_nodes``,
        ``subg_nodes_seq``, ``subg_edges``, ``subg_masks``, ``subg_rd``
        (only when ``use_rd``), ``subg_node_size``, ``subg_edge_size`` and
        ``subg_mask_size`` set.
    """
    assert isinstance(data, Data)
    edge_index, num_nodes = data.edge_index, data.num_nodes

    subg_nodes = []      # global node ids of each subgraph (to gather node embeddings)
    subg_nodes_seq = []  # slot of each subgraph node in the sequence
    subg_edges = []      # subgraph edges expressed in sequence slots
    subg_masks = []      # mask over all graph edges (to gather edge embeddings)
    subg_rd = []
    subg_num_nodes = []
    subg_num_edges = []
    subg_mask_size = []

    data.total_seq_nodes = seq_len + 1  # 1 is used for the hop node
    data.number_edges = edge_index.size(1)

    for ind in range(num_nodes):
        nodes_, edge_index_, edge_mask_, z_ = k_hop_subgraph(
            ind, h, edge_index, True, num_nodes, node_label=node_label,
            max_nodes_per_hop=max_nodes_per_hop
        )
        seq_index, seq_edge_index = seq_label_trans_mulabel(
            z=z_,
            edge_index=edge_index_,
            h=h,
            seqlabel_dict=seqlabel_dict)

        subg_num_nodes.append(nodes_.size(0))
        # Tensor.sum returns a tensor even for an empty mask, so one code path
        # covers edgeless graphs too.
        subg_num_edges.append(int(edge_mask_.sum()))
        subg_mask_size.append(edge_mask_.size(0))

        subg_nodes.append(nodes_)
        subg_nodes_seq.append(seq_index)
        subg_edges.append(seq_edge_index)
        subg_masks.append(edge_mask_)

        if use_rd:
            if nodes_.size(0) > 1:
                # Resistance distance from the root (local index 0) to every
                # node via the Laplacian pseudo-inverse.
                adj = to_scipy_sparse_matrix(
                    edge_index_, num_nodes=nodes_.shape[0]
                ).tocsr()
                laplacian = ssp.csgraph.laplacian(adj).toarray()
                try:
                    L_inv = linalg.pinv(laplacian)
                except linalg.LinAlgError:
                    # SVD did not converge: retry on a slightly regularised Laplacian.
                    L_inv = linalg.pinv(laplacian + 0.01 * np.eye(*laplacian.shape))
                rd_to_x = L_inv[0, 0] + np.diag(L_inv) - L_inv[0, :] - L_inv[:, 0]
                subg_rd.append(torch.FloatTensor(rd_to_x).reshape(-1))
            else:
                # A lone root is at distance 0 from itself.
                subg_rd.append(torch.FloatTensor([0]))

    data.subg_nodes = torch.cat(subg_nodes, dim=0)
    data.subg_nodes_seq = torch.cat(subg_nodes_seq, dim=0)
    # NOTE(review): reshape(-1, 2) of the (2, total) tensor does not transpose.
    # Consumers must undo it with reshape(2, -1). Confirm this is intended.
    data.subg_edges = torch.cat(subg_edges, dim=1).reshape(-1, 2)
    data.subg_masks = torch.cat(subg_masks, dim=0)
    if use_rd:
        data.subg_rd = torch.cat(subg_rd, dim=0)
    data.subg_node_size = subg_num_nodes
    data.subg_edge_size = subg_num_edges
    data.subg_mask_size = subg_mask_size

    return data


def seq_label_trans_mulabel(z, edge_index, h, seqlabel_dict):
    """Map subgraph nodes to sequence slots via multi-label structural tuples.

    Each non-root node ``v`` gets the tuple
    ``(hop, deg, dc1, dc2, dc3, induc_n, peri2, peri3)``:

    * ``hop``: hop distance of ``v`` from the root;
    * ``deg``: subgraph degree of ``v``;
    * ``dc1``/``dc2``/``dc3``: v's adjacency row sums over the previous / same /
      next hop (0 on the last hop);
    * ``induc_n``: number of nodes from v's hop onward reachable by a length-2 walk;
    * ``peri2``/``peri3``: sum of adjacency entries among v's same-hop / next-hop
      neighbours.

    ``v`` is placed at ``seqlabel_dict[tuple]``. The root (local index 0)
    starts at slot 0. Nodes that start at the same slot are spread over
    consecutive slots ordered by their igraph canonical label (coloured by hop).

    Assumes nodes are ordered by hop with the root first (as ``k_hop_subgraph``
    returns them).

    Args:
        z: hop label per node, shape ``(num_node, 1)``.
        edge_index: ``(2, E)`` subgraph edges in local indices.
        h: number of hops.
        seqlabel_dict: mapping structural tuple -> start slot.

    Returns:
        tuple ``(seq_index, seq_edge_index)``: slot per node and edges rewritten in slots.

    Raises:
        KeyError: if a node's tuple is missing from ``seqlabel_dict``.
    """
    z = z.squeeze().tolist()
    if type(z) is int:
        z = [z]
    num_node_ = len(z)
    g = igraph.Graph(directed=False)
    g.add_vertices(num_node_)
    g.add_edges([(i, j) for i, j in zip(edge_index[0], edge_index[1])])
    cl = g.canonical_permutation(color=z)
    adj = np.array(list(g.get_adjacency()))
    two_adj = np.matmul(adj, adj)
    deg = g.degree()
    z_arr = np.array(z)

    padded_label = [0] * num_node_
    accum = 1  # local index of the first node of the current hop (root is 0)
    for hop in range(1, h + 1):
        hop_nodes = sum(z_arr == hop)
        seq0_hop_idx = z_arr == hop - 1
        seq1_hop_idx = z_arr == hop
        seq2_hop_idx = z_arr == hop + 1
        outer = np.zeros(num_node_, dtype=bool)
        outer[accum:] = True  # current hop and beyond
        for idx in range(accum, accum + hop_nodes):
            dc1 = np.sum(adj[idx, seq0_hop_idx])
            dc2 = np.sum(adj[idx, seq1_hop_idx])
            induc_n = np.sum(two_adj[idx, outer] > 0.5)
            nbr = np.zeros(num_node_, dtype=bool)
            nbr[seq1_hop_idx] = adj[idx, seq1_hop_idx] > 0.5
            peri2 = np.sum(adj[:, nbr][nbr, :])
            if hop < h:
                dc3 = np.sum(adj[idx, seq2_hop_idx])
                nbr = np.zeros(num_node_, dtype=bool)
                nbr[seq2_hop_idx] = adj[idx, seq2_hop_idx] > 0.5
                peri3 = np.sum(adj[:, nbr][nbr, :])
            else:
                dc3 = 0
                peri3 = 0
            label_tuple = (z[idx], deg[idx], dc1, dc2, dc3, induc_n, peri2, peri3)
            padded_label[idx] = seqlabel_dict[label_tuple]
        accum += hop_nodes

    seq_index = torch.LongTensor(padded_label)
    cl = torch.LongTensor(cl)
    # Group on a snapshot of the start slots so that offsets written for one
    # group cannot change membership of another group.
    base = seq_index.clone()
    for value in torch.unique(base):
        members = base == value
        if int(members.sum()) > 1:
            # argsort(argsort(.)) yields each member's rank by canonical label;
            # a single argsort would give sorting indices, not ranks.
            seq_index[members] += torch.argsort(torch.argsort(cl[members]))
    seq_edge_index = seq_index[edge_index]
    return seq_index, seq_edge_index


def neighbors(fringe, A):
    """Return every node linked to a node of ``fringe`` by an in- or out-edge.

    Args:
        fringe: iterable of node indices.
        A: scipy sparse adjacency matrix; a nonzero ``A[u, v]`` links u and v.

    Returns:
        set: indices of all nonzero columns of ``A[u, :]`` and rows of
        ``A[:, u]`` for ``u`` in ``fringe``.
    """
    found = set()
    for u in fringe:
        found.update(ssp.find(A[u, :])[1])  # out-neighbours (column indices)
        found.update(ssp.find(A[:, u])[0])  # in-neighbours (row indices)
    return found


class return_prob(object):
    """Transform that attaches random-walk return probabilities to a graph.

    After adding a self-loop to every node, the transition matrix is
    ``P = D^-1 (A + I)``, and ``data.rp[v, i]`` is ``(P^(i+1))[v, v]`` for
    ``i`` in ``range(steps)``. Fewer than 5 steps uses repeated sparse products.
    Otherwise the diagonal comes from an eigendecomposition of the symmetric
    ``D^-1/2 (A + I) D^-1/2``, which has the same diagonal powers as ``P``.
    """

    def __init__(self, steps=50):
        self.steps = steps

    def __call__(self, data):
        n = data.num_nodes
        adj = to_scipy_sparse_matrix(data.edge_index, num_nodes=n).tocsr()
        adj += ssp.identity(n, dtype='int', format='csr')
        # adj.sum(1) is an (n, 1) matrix; flatten it so setdiag gets a 1-D vector.
        deg = np.asarray(adj.sum(1)).ravel()
        rp = np.empty([n, self.steps])
        inv_deg = ssp.lil_matrix((n, n))
        inv_deg.setdiag(1 / deg)
        P = inv_deg @ adj
        if self.steps < 5:
            Pi = P
            for i in range(self.steps):
                rp[:, i] = Pi.diagonal()
                Pi = Pi @ P
        else:
            inv_sqrt_deg = ssp.lil_matrix((n, n))
            inv_sqrt_deg.setdiag(1 / deg ** 0.5)
            B = inv_sqrt_deg @ adj @ inv_sqrt_deg
            L, U = eigh(B.todense())
            W = U * U  # squared eigenvector entries
            Li = L
            for i in range(self.steps):
                rp[:, i] = W.dot(Li)
                Li = Li * L

        data.rp = torch.FloatTensor(rp)

        return data


def encode_y_to_arr(data, vocab2idx, max_seq_len):
    '''
    Encode the target word sequence of a PyG graph and store it on the graph.

    Input:
        data: PyG graph object whose ``y`` is a list of words
        vocab2idx: word -> index mapping (must contain '__UNK__' / '__EOS__')
        max_seq_len: length of the encoded sequence
    Output:
        the same ``data`` with ``y_arr`` set (see encode_seq_to_arr)
    '''
    # PyG >= 1.5.0 keeps the word list directly in data.y
    # (with PyG 1.4.3 it was data.y[0]).
    data.y_arr = encode_seq_to_arr(data.y, vocab2idx, max_seq_len)
    return data


def encode_seq_to_arr(seq, vocab2idx, max_seq_len):
    '''
    Convert a list of words into a fixed-length index tensor.

    Input:
        seq: list of words
        vocab2idx: word -> index mapping; unknown words map to '__UNK__'
        max_seq_len: output length (longer input is truncated, shorter input
            is padded with '__EOS__')
    Output:
        LongTensor of shape (1, max_seq_len)
    '''
    padding = ['__EOS__'] * max(0, max_seq_len - len(seq))
    tokens = seq[:max_seq_len] + padding
    ids = []
    for token in tokens:
        ids.append(vocab2idx[token] if token in vocab2idx else vocab2idx['__UNK__'])
    return torch.tensor([ids], dtype=torch.long)


def get_vocab_mapping(seq_list, num_vocab):
    '''
        Build a vocabulary from the ``num_vocab`` most frequent words.

        Input:
            seq_list: a list of sequences (each a list of words)
            num_vocab: vocabulary size (not counting the two special tokens)
        Output:
            vocab2idx:
                A dictionary that maps vocabulary into integer index.
                Additionally, '__UNK__' and '__EOS__' are indexed:
                '__UNK__' : out-of-vocabulary term
                '__EOS__' : end-of-sentence
                Frequency ties are broken by first occurrence in seq_list.

            idx2vocab:
                A list that maps idx to actual vocabulary.

        Also prints the fraction of word occurrences covered by the kept words.
    '''
    # Counter remembers first-occurrence order, which the stable argsort below
    # relies on for tie-breaking.
    word_counts = Counter(w for seq in seq_list for w in seq)
    words = list(word_counts)
    counts = np.array([word_counts[w] for w in words])
    ranked = np.argsort(-counts, kind='stable')[:num_vocab]

    print('Coverage of top {} vocabulary:'.format(num_vocab))
    print(float(np.sum(counts[ranked])) / np.sum(counts))

    idx2vocab = [words[k] for k in ranked]
    vocab2idx = {word: pos for pos, word in enumerate(idx2vocab)}

    # Special tokens go right after the regular vocabulary.
    num_regular = len(vocab2idx)
    vocab2idx['__UNK__'] = num_regular  # num_vocab
    idx2vocab.append('__UNK__')
    vocab2idx['__EOS__'] = num_regular + 1  # num_vocab + 1
    idx2vocab.append('__EOS__')

    # Sanity check: the two mappings must be inverses of each other.
    for pos, word in enumerate(idx2vocab):
        assert (pos == vocab2idx[word])

    # decode_arr_to_seq relies on '__EOS__' being the last index.
    assert (vocab2idx['__EOS__'] == len(idx2vocab) - 1)

    return vocab2idx, idx2vocab


def augment_edge(data):
    '''
        Add inverse AST edges, next-token edges and inverse next-token edges.

        Input:
            data: PyG data object with ``edge_index`` (AST edges) and
                ``node_is_attributed``; nodes are assumed to already be
                numbered in DFS order.
        Output:
            the same data object, modified in place:
                data.edge_index: AST edges, inverse AST edges, next-token edges
                    and inverse next-token edges, concatenated in that order.
                data.edge_attr (float tensor of shape [num_edges, 2], default dtype):
                    data.edge_attr[:,0]: whether it is an AST edge (0) or a next-token edge (1)
                    data.edge_attr[:,1]: whether it is original direction (0) or inverse direction (1)
    '''

    ##### AST edge
    edge_index_ast = data.edge_index
    edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2))

    ##### Inverse AST edge
    edge_index_ast_inverse = torch.stack([edge_index_ast[1], edge_index_ast[0]], dim=0)
    edge_attr_ast_inverse = torch.cat(
        [torch.zeros(edge_index_ast_inverse.size(1), 1), torch.ones(edge_index_ast_inverse.size(1), 1)], dim=1)

    ##### Next-token edge
    # Nodes are already numbered in DFS order, so the ascending indices of the
    # attributed nodes are also their DFS order.
    attributed_node_idx_in_dfs_order = torch.where(data.node_is_attributed.view(-1, ) == 1)[0]

    # Link consecutive attributed nodes, e.g.
    #   [1, 3, 4, 5, 8, 9, 12] -> [[1, 3, 4, 5, 8, 9],
    #                              [3, 4, 5, 8, 9, 12]]
    edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]],
                                      dim=0)
    edge_attr_nextoken = torch.cat(
        [torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim=1)

    ##### Inverse next-token edge
    edge_index_nextoken_inverse = torch.stack([edge_index_nextoken[1], edge_index_nextoken[0]], dim=0)
    edge_attr_nextoken_inverse = torch.ones((edge_index_nextoken.size(1), 2))

    data.edge_index = torch.cat(
        [edge_index_ast, edge_index_ast_inverse, edge_index_nextoken, edge_index_nextoken_inverse], dim=1)
    data.edge_attr = torch.cat([edge_attr_ast, edge_attr_ast_inverse, edge_attr_nextoken, edge_attr_nextoken_inverse],
                               dim=0)

    return data


def augment_edge2(data):
    '''
        Add next-token edges to an AST graph (no inverse edges are added).

        Input:
            data: PyG data object with ``edge_index`` (AST edges) and
                ``node_is_attributed``; nodes are assumed to already be
                numbered in DFS order.
        Output:
            the same data object, modified in place:
                data.edge_index: AST edges followed by next-token edges.
                data.edge_attr (float tensor of shape [num_edges, 2], default dtype):
                    data.edge_attr[:,0]: whether it is an AST edge (0) or a next-token edge (1)
                    data.edge_attr[:,1]: always 0 here (original direction only)
    '''

    ##### AST edge
    edge_index_ast = data.edge_index
    edge_attr_ast = torch.zeros((edge_index_ast.size(1), 2))

    ##### Next-token edge
    # Nodes are already numbered in DFS order, so the ascending indices of the
    # attributed nodes are also their DFS order.
    attributed_node_idx_in_dfs_order = torch.where(data.node_is_attributed.view(-1, ) == 1)[0]

    # Link consecutive attributed nodes, e.g.
    #   [1, 3, 4, 5, 8, 9, 12] -> [[1, 3, 4, 5, 8, 9],
    #                              [3, 4, 5, 8, 9, 12]]
    edge_index_nextoken = torch.stack([attributed_node_idx_in_dfs_order[:-1], attributed_node_idx_in_dfs_order[1:]],
                                      dim=0)
    edge_attr_nextoken = torch.cat(
        [torch.ones(edge_index_nextoken.size(1), 1), torch.zeros(edge_index_nextoken.size(1), 1)], dim=1)

    data.edge_index = torch.cat([edge_index_ast, edge_index_nextoken], dim=1)
    data.edge_attr = torch.cat([edge_attr_ast, edge_attr_nextoken], dim=0)

    return data

def canonical_label1(data):
    # Canonical labelling for code2 graphs, with edges treated as undirected.
    # Stores igraph's canonical permutation (a plain list) in data.cl and the
    # node count in data.num_vertex; returns the same data object.
    num_vertex = data.x.shape[0]
    src, dst = data.edge_index[0], data.edge_index[1]
    edge_list = list(zip(src, dst))

    graph = igraph.Graph(directed=False)
    graph.add_vertices(num_vertex)
    graph.add_edges(edge_list)

    data.cl = graph.canonical_permutation()
    data.num_vertex = num_vertex
    return data

def hop2_neighbors(g,i):
    exist_nei = [i] + g.neighbors(i)
    add_nei = []
    for j in exist_nei:
        for h in g.neighbors(j):
            if h not in exist_nei:
                add_nei.append(h)
    return exist_nei + add_nei

def canonical_label2(data):
    # Canonical labelling for code2 graphs on the directed graph, with every
    # vertex coloured by len(hop2_neighbors(graph, v)) so that only vertices of
    # equal colour may be mapped onto each other.
    # Stores the permutation as a LongTensor on data.x's device in data.cl and
    # the node count in data.num_vertex; returns the same data object.
    num_vertex = data.x.shape[0]
    src, dst = data.edge_index[0], data.edge_index[1]
    edge_list = list(zip(src, dst))

    graph = igraph.Graph(directed=True)
    graph.add_vertices(num_vertex)
    graph.add_edges(edge_list)

    vertex_colors = [len(hop2_neighbors(graph, v)) for v in range(num_vertex)]
    permutation = graph.canonical_permutation(color=vertex_colors)

    data.cl = torch.LongTensor(permutation).to(data.x.device)
    data.num_vertex = num_vertex
    return data

#### NGNN utils

def _resistance_distance_to_root(edge_index_, num_nodes_):
    # Resistance distance from local node 0 to every node of one subgraph,
    # returned as a FloatTensor of shape (num_nodes_, 1).
    # See "Link prediction in complex networks: A survey".
    adj = to_scipy_sparse_matrix(edge_index_, num_nodes=num_nodes_).tocsr()
    laplacian = ssp.csgraph.laplacian(adj).toarray()
    try:
        L_inv = linalg.pinv(laplacian)
    except np.linalg.LinAlgError:
        # pinv's SVD did not converge: retry on a slightly regularised
        # Laplacian so that L_inv is always defined below.
        L_inv = linalg.pinv(laplacian + 0.01 * np.eye(*laplacian.shape))
    # rd(x, y) = L+[x, x] + L+[y, y] - L+[x, y] - L+[y, x], with x = local node 0.
    lxx = L_inv[0, 0]
    lyy = np.diag(L_inv)
    lxy = L_inv[0, :]
    lyx = L_inv[:, 0]
    return torch.FloatTensor(lxx + lyy - lxy - lyx).unsqueeze(1)


def create_subgraphs(data, h=1, sample_ratio=1.0, max_nodes_per_hop=None,
                     node_label='hop', use_rd=False, subgraph_pretransform=None):
    '''
    Extract an h-hop rooted subgraph for every node of ``data`` and combine the
    node-subgraphs into one large disconnected graph (a ``Batch``).

    Args:
        data: PyG ``Data`` object.
        h: hop count, or a list of hop counts (one combined graph per entry).
        sample_ratio: accepted for API compatibility; not used here.
        max_nodes_per_hop, node_label: forwarded to ``k_hop_subgraph``.
        use_rd: if True, store each subgraph's resistance distances to its
            local node 0 in ``rd`` (presumably the root node, as placed by
            ``k_hop_subgraph`` -- TODO confirm).
        subgraph_pretransform: optional callable applied to each subgraph
            (used for k-GNN).

    Returns:
        The combined graph when ``h`` is an int or a one-element list;
        otherwise a dict mapping each hop count to its combined graph.
        The combined graph keeps the original ``edge_index``/``edge_attr``/``pos``
        as ``original_*``, exposes per-node subgraph ids as ``node_to_subgraph``
        and copies the remaining graph-level attributes of ``data``.
    '''
    if isinstance(h, int):
        h = [h]
    assert (isinstance(data, Data))
    x, edge_index, num_nodes = data.x, data.edge_index, data.num_nodes
    has_node_type = 'node_type' in data

    new_data_multi_hop = {}
    for h_ in h:
        subgraphs = []
        for ind in range(num_nodes):
            nodes_, edge_index_, edge_mask_, z_ = k_hop_subgraph(
                ind, h_, edge_index, True, num_nodes, node_label=node_label,
                max_nodes_per_hop=max_nodes_per_hop
            )
            x_ = x[nodes_] if x is not None else None
            edge_attr_ = data.edge_attr[edge_mask_] if data.edge_attr is not None else None
            pos_ = data.pos[nodes_] if data.pos is not None else None

            data_ = data.__class__(x_, edge_index_, edge_attr_, None, pos_, z=z_)
            data_.num_nodes = nodes_.shape[0]

            if has_node_type:
                data_.node_type = data.node_type[nodes_]

            if use_rd:
                data_.rd = _resistance_distance_to_root(edge_index_, nodes_.shape[0])

            if subgraph_pretransform is not None:  # for k-gnn
                data_ = subgraph_pretransform(data_)
                if 'assignment_index_2' in data_:
                    data_.batch_2 = torch.zeros(
                        data_.iso_type_2.shape[0], dtype=torch.long
                    )
                if 'assignment_index_3' in data_:
                    data_.batch_3 = torch.zeros(
                        data_.iso_type_3.shape[0], dtype=torch.long
                    )

            subgraphs.append(data_)

        # new_data is treated as a big disconnected graph of the batch of subgraphs
        new_data = Batch.from_data_list(subgraphs)
        new_data.num_nodes = sum(data_.num_nodes for data_ in subgraphs)
        new_data.num_subgraphs = len(subgraphs)

        new_data.original_edge_index = edge_index
        new_data.original_edge_attr = data.edge_attr
        new_data.original_pos = data.pos

        # rename batch, because batch will be used to store node_to_graph assignment
        new_data.node_to_subgraph = new_data.batch
        del new_data.batch
        if 'batch_2' in new_data:
            new_data.assignment2_to_subgraph = new_data.batch_2
            del new_data.batch_2
        if 'batch_3' in new_data:
            new_data.assignment3_to_subgraph = new_data.batch_3
            del new_data.batch_3

        # create a subgraph_to_graph assignment vector (all zero)
        new_data.subgraph_to_graph = torch.zeros(len(subgraphs), dtype=torch.long)

        # copy remaining graph attributes
        for k, v in data:
            if k not in ['x', 'edge_index', 'edge_attr', 'pos', 'num_nodes', 'batch',
                         'z', 'rd', 'node_type']:
                new_data[k] = v

        if len(h) == 1:
            return new_data
        else:
            new_data_multi_hop[h_] = new_data

    return new_data_multi_hop