import numpy
import torch
from torch_geometric.data import Data

class OrderedData(Data):
    """Graph sample that carries topological-ordering tensors next to ``edge_index``/``x``.

    The extra attributes ``forward_level``, ``forward_index``, ``backward_level``
    and ``backward_index`` correspond to the four values returned by
    ``return_order_info``. They are stored exactly as passed in; this
    constructor does not compute them.
    """

    def __init__(
        self,
        edge_index=None,
        x=None,
        forward_level=None,
        forward_index=None,
        backward_level=None,
        backward_index=None,
    ):
        super().__init__()
        # Assignment order mirrors the constructor signature.
        fields = (
            ("edge_index", edge_index),
            ("x", x),
            ("forward_level", forward_level),
            ("forward_index", forward_index),
            ("backward_level", backward_level),
            ("backward_index", backward_index),
        )
        for attr_name, attr_value in fields:
            setattr(self, attr_name, attr_value)

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset PyG adds to attribute ``key`` when collating graphs.

        Attributes whose name contains ``'index'`` or ``'face'`` (``edge_index``,
        ``forward_index``, ``backward_index``, ...) hold node ids, so they are
        shifted by this graph's node count; all other attributes are not shifted.
        """
        holds_node_ids = 'index' in key or 'face' in key
        return self.num_nodes if holds_node_ids else 0

    def __cat_dim__(self, key, value, *args, **kwargs):
        """Return the dimension along which attribute ``key`` is concatenated.

        Only ``edge_index`` is concatenated along dim 1 (its second axis indexes
        edges, cf. ``edge_index[0]``/``edge_index[1]`` in ``top_sort``). Every
        other key, notably ``forward_index``/``backward_index``, uses dim 0.
        NOTE(review): presumably overridden because PyG's default special-cases
        any key containing "index" — confirm against the installed PyG version.
        """
        return 1 if key == "edge_index" else 0


def return_order_info(edge_index, num_nodes):
    """Compute forward/backward topological levels and node indices for a DAG.

    Args:
        edge_index: Edges of shape (2, num_edges); row 0 holds parent ids and
            row 1 child ids (the convention ``top_sort`` uses). A tensor, numpy
            array or nested sequence of ints is accepted.
        num_nodes: Number of nodes; ids are assumed to be ``0..num_nodes-1``.

    Returns:
        Tuple ``(forward_level, forward_index, backward_level, backward_index)``:

        - ``forward_level``: each node's level on ``edge_index`` (sources are 0).
        - ``backward_level``: each node's level on the reversed edges (sinks are 0).
        - ``forward_index`` and ``backward_index``: ``arange(num_nodes)`` as two
          independent long tensors.
    """
    forward_level = top_sort(edge_index, num_nodes)

    # Backward levels are recomputed on the edge-reversed graph instead of being
    # derived from forward_level: "max_level - forward_level" differs in general,
    # because a node's longest path from a source is unrelated to its longest
    # path to a sink. flip(0) swaps the parent/child rows in one vectorized op.
    reversed_edges = torch.as_tensor(edge_index, dtype=torch.long).flip(0)
    backward_level = top_sort(reversed_edges, num_nodes)

    forward_index = torch.arange(num_nodes, dtype=torch.long)
    # Separate tensor so in-place edits to one index cannot alter the other.
    backward_index = forward_index.clone()

    return forward_level, forward_index, backward_level, backward_index


# See https://github.com/unbounce/pytorch-tree-lstm/blob/66f29a44e98c7332661b57d22501107bcb193f90/treelstm/util.py#L8
# Assumes nodes are numbered consecutively starting at 0.
def top_sort(edge_index, graph_size):
    """Assign each node of a directed acyclic graph its topological level.

    A node with no incoming edges is at level 0; every other node is at
    ``1 + max(level of its parents)``, i.e. the length of the longest path
    reaching it from a source. Nodes are processed one whole level at a time.

    Args:
        edge_index: Edges of shape (2, num_edges); row 0 holds parent ids and
            row 1 child ids. A torch tensor (on any device), numpy array or
            nested sequence of ints is accepted.
        graph_size: Number of nodes; ids must lie in ``0..graph_size-1``.

    Returns:
        A long tensor of shape (graph_size,) holding each node's level.

    Raises:
        ValueError: If an edge references a node id outside
            ``0..graph_size-1``, or if the graph contains a cycle (including a
            self-loop), in which case no level assignment exists.
    """
    if isinstance(edge_index, torch.Tensor):
        edge_index = edge_index.detach().cpu().numpy()
    edges = numpy.asarray(edge_index, dtype=numpy.int64)
    if edges.size == 0:
        # An empty edge list of any shape simply means "no edges".
        edges = edges.reshape(2, 0)
    elif edges.min() < 0 or edges.max() >= graph_size:
        raise ValueError(
            f"top_sort: edge_index contains node ids outside [0, {graph_size})"
        )

    parent_nodes = edges[0]
    child_nodes = edges[1]

    node_order = numpy.zeros(graph_size, dtype=numpy.int64)
    unevaluated_nodes = numpy.ones(graph_size, dtype=bool)

    level = 0
    while unevaluated_nodes.any():
        # A node is blocked while at least one of its parents is unevaluated.
        blocked = numpy.zeros(graph_size, dtype=bool)
        blocked[child_nodes[unevaluated_nodes[parent_nodes]]] = True

        nodes_to_evaluate = unevaluated_nodes & ~blocked
        if not nodes_to_evaluate.any():
            # Every remaining node waits on another remaining node: a cycle.
            # Without this check the loop would spin forever.
            raise ValueError(
                "top_sort: graph contains a cycle; "
                f"{int(unevaluated_nodes.sum())} node(s) cannot be ordered"
            )

        node_order[nodes_to_evaluate] = level
        unevaluated_nodes[nodes_to_evaluate] = False
        level += 1

    return torch.from_numpy(node_order).long()