import torch

from torch_geometric.data.hetero_data import HeteroData
from torch_geometric.data.data import Data
from torch_geometric.utils import coalesce, to_dense_adj, dense_to_sparse


def get_dists(start_node: int, edge_index: torch.Tensor) -> torch.Tensor:
    """Compute breadth-first hop distances from ``start_node``.

    Edges are followed in the direction ``edge_index[0] -> edge_index[1]``.

    Args:
        start_node: Index of the node the search starts from.
        edge_index: Integer tensor whose row 0 holds source node ids and row 1
            target node ids (expected shape ``(2, E)``).

    Returns:
        An ``int32`` tensor (on ``edge_index.device``) with one entry per node id
        in ``0..max(edge_index)`` -- extended to cover ``start_node`` if it is
        larger. Each entry holds the hop distance from ``start_node``, or ``-1``
        if the node is not reachable.
    """
    start = int(start_node)
    device = edge_index.device
    row, col = edge_index[0], edge_index[1]
    num_nodes = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0
    # Make sure the start node itself has a slot even if it has no edges.
    num_nodes = max(num_nodes, start + 1)

    dists = torch.full((num_nodes,), -1, dtype=torch.int32, device=device)
    dists[start] = 0
    # The BFS frontier must be seeded with the start node (previously node 0).
    frontier = torch.tensor([start], dtype=torch.long, device=device)
    in_frontier = torch.zeros(num_nodes, dtype=torch.bool, device=device)
    hop = 0
    while bool((dists == -1).any()):
        hop += 1
        # Select every edge whose source lies in the current frontier.
        in_frontier[frontier] = True
        edge_mask = torch.index_select(in_frontier, 0, row)
        in_frontier[frontier] = False  # reset the scratch mask for the next hop
        neighbours = col[edge_mask]
        # Only nodes not yet visited form the next frontier.
        frontier = neighbours[dists[neighbours] == -1]
        if frontier.numel() == 0:
            break  # remaining nodes are unreachable
        dists[frontier] = hop
    return dists


def to_hetero(graph: Data) -> HeteroData:
    """Wrap a homogeneous ``Data`` graph into a ``HeteroData`` container.

    A fixed set of graph-level attributes is read from ``graph`` and set as
    attributes on the new object. Node features go to the node type ``"x"``.
    The original connectivity (``edge_index`` and ``edge_attr``) is stored
    under the edge type ``("x", "level0", "x")``.
    """
    shared_attrs = (
        "pos",
        "object_ids",
        "collider_ids",
        "deformable_ids",
        "node_type",
    )
    hetero = HeteroData()
    for name in shared_attrs:
        value = getattr(graph, name)
        setattr(hetero, name, value)

    hetero["x"].x = graph.x
    base_level = hetero["x", "level0", "x"]
    base_level.edge_index = graph.edge_index
    base_level.edge_attr = graph.edge_attr
    return hetero


def get_reduced_edges(edge_index: torch.Tensor, start_node: int = -1) -> torch.Tensor:
    """Coarsen a graph using a BFS-parity split of its nodes.

    Nodes are split by the parity of their BFS distance from ``start_node``:

    - Nodes at even distance are "marked" and kept.
    - All others are unmarked. This covers odd distance and unreachable nodes.

    The result contains an edge ``u -> v`` between two distinct marked nodes
    whenever the input has edges ``u -> w`` and ``w -> v`` for some unmarked
    ``w``. Self-loops and duplicate edges are removed.

    Args:
        edge_index: Integer tensor of edges, row 0 = sources, row 1 = targets
            (expected shape ``(2, E)``). Node ids may be non-contiguous.
        start_node: Original node id to start the BFS from. ``-1`` (default)
            uses the smallest node id present in ``edge_index``.

    Returns:
        Tensor of the new edges, expressed in the original node ids. If
        ``edge_index`` is empty, an empty ``(2, 0)`` tensor is returned.

    Raises:
        ValueError: If ``start_node`` is given but does not occur in
            ``edge_index``.
    """
    if edge_index.numel() == 0:
        return edge_index.new_empty((2, 0))

    device = edge_index.device
    # Compact the (possibly sparse) node ids to 0..num_nodes-1 so the dense
    # adjacency matrices below only cover nodes that actually occur.
    unique_nodes = torch.unique(edge_index)  # sorted ascending
    present_nodes = torch.zeros(
        int(unique_nodes.max()) + 1, dtype=torch.long, device=device
    )
    present_nodes[unique_nodes] = 1
    reindex = present_nodes.cumsum(0) - 1
    edge_index = reindex[edge_index]
    num_nodes = unique_nodes.numel()

    # ``start_node`` is an original id; get_dists works on compacted ids.
    if start_node == -1:
        local_start = 0  # compact id 0 is the smallest original id
    else:
        start_node = int(start_node)
        if (
            not 0 <= start_node < present_nodes.numel()
            or int(present_nodes[start_node]) == 0
        ):
            raise ValueError(f"start_node {start_node} does not occur in edge_index")
        local_start = int(reindex[start_node])

    dists = get_dists(local_start, edge_index)
    # Unreachable nodes have distance -1. torch's ``%`` follows the sign of the
    # divisor, so -1 % 2 == 1 and those nodes count as unmarked.
    marked = (dists % 2 == 0).to(device)

    row, column = edge_index[0], edge_index[1]
    from_marked = edge_index[:, marked[row] & ~marked[column]]
    to_marked = edge_index[:, ~marked[row] & marked[column]]
    # Pin the dense size to num_nodes. Otherwise to_dense_adj sizes the matrix
    # by the largest id in the given edges, which can be smaller than
    # len(marked) and break the boolean-mask indexing below.
    from_adj = to_dense_adj(from_marked, max_num_nodes=num_nodes)[0]
    to_adj = to_dense_adj(to_marked, max_num_nodes=num_nodes)[0]

    from_adj = from_adj[marked][:, ~marked]  # (num_marked, num_unmarked)
    to_adj = to_adj[~marked][:, marked]  # (num_unmarked, num_marked)

    # Entry (i, j) is non-zero iff marked_i -> unmarked -> marked_j exists.
    adj = from_adj @ to_adj
    new_edge_index, _ = dense_to_sparse(adj)
    new_edge_index = new_edge_index[
        :, new_edge_index[0] != new_edge_index[1]
    ]  # remove self-edges
    new_edge_index = coalesce(new_edge_index)  # remove duplicate edges
    marked_positions = torch.arange(len(marked), device=device)[marked]
    new_edge_index = marked_positions[new_edge_index]  # marked-subset ids -> compact ids
    return unique_nodes[new_edge_index]  # compact ids -> original ids


def build_hierarchical_graph(graph: Data, num_levels: int) -> HeteroData:
    """Build a multi-level edge hierarchy on top of ``graph``.

    Level 0 holds the original edges, as produced by ``to_hetero``. Each level
    ``k`` in ``1..num_levels`` is built as follows:

    - ``get_reduced_edges`` is applied to the edges of level ``k - 1``.
    - The result is stored as ``edge_index`` of the edge type
      ``("x", f"level{k}", "x")``.
    """
    hierarchy = to_hetero(graph)
    edges = graph.edge_index
    for depth in range(1, num_levels + 1):
        edges = get_reduced_edges(edges)
        hierarchy["x", f"level{depth}", "x"].edge_index = edges
    return hierarchy
