import time
import torch
import networkx as nx
from multiprocessing import Pool
from utils.utils import get_dataset, get_folder_from_dset


def compute_digraph_from_dataset(
    dset: str = "cora_ml",
    path: str = "./",
):
    """Load ``dset`` from ``path`` and return its edges as a loop-free ``nx.DiGraph``.

    Every row of ``edge_index.T`` is added as one directed edge (presumably
    ``edge_index`` has shape ``(2, num_edges)`` with source/target rows --
    TODO confirm against ``get_dataset``). Self-loops are removed because the
    lifting is computed on a loop-free graph.

    Note: only nodes that occur in at least one edge end up in the graph; a
    node whose only edges were self-loops stays in the graph as an isolated
    node.

    Returns:
        networkx.DiGraph with no self-loops.
    """
    data = get_dataset(name=dset, root_dir=path)
    edge_pairs = data.edge_index.T.numpy()

    digraph = nx.DiGraph()
    digraph.add_edges_from(edge_pairs)

    # The lifting needs a loop-free graph; collect every (v, v) edge first so
    # the graph is not mutated while the generator is still reading from it.
    loops = list(nx.selfloop_edges(digraph))
    digraph.remove_edges_from(loops)
    assert all(
        src != dst for src, dst in digraph.edges
    ), "Graph contains self-loops after removal!"

    return digraph


def find_simplex_rec(possible_next_vertices, prefix, all_simplices, max_dim, G):
    """
    Recursively enumerate directed cliques (ordered simplices) of a graph.

    A simplex is an ordered vertex list ``[v0, v1, ..., vk]`` in which each
    vertex is returned by ``G.neighbors`` of every vertex before it. The search
    is depth-first: each recursive call extends ``prefix`` by one vertex taken
    from ``possible_next_vertices``.

    Parameters:
        possible_next_vertices: candidate next vertices (the seed vertices at
            the top-level call).
        prefix: list of vertices visited so far (empty at the top-level call).
        all_simplices: accumulator shared by the recursion. It is ignored at
            the top-level call, where a fresh list of ``max_dim`` sublists is
            created; sublist ``i`` collects simplices with ``i + 1`` vertices.
        max_dim: maximum number of vertices per simplex.
        G: graph exposing ``neighbors(v)`` (for ``networkx.DiGraph`` these are
            the successors). It must have no self-loops, otherwise repeated
            vertices such as ``[v, v]`` are produced.

    Returns:
        At the top-level call (empty ``prefix``): the list of ``max_dim``
        sublists of simplices (empty when ``max_dim <= 0``).
        Recursive calls return ``None``.
    """
    is_head = not prefix
    if is_head:
        # Recursion head: start a fresh accumulator.
        all_simplices = [[] for _ in range(max_dim)]
    else:
        all_simplices[len(prefix) - 1].append(prefix)

    if len(prefix) >= max_dim:
        # Also covers max_dim <= 0 at the head: hand back the (empty)
        # accumulator instead of None so callers can always iterate it.
        return all_simplices if is_head else None

    # Loop-invariant: build the candidate set once per call, not per vertex.
    candidates = None if is_head else set(possible_next_vertices)
    for v in possible_next_vertices:
        if is_head:
            next_candidates = list(G.neighbors(v))
        else:
            # Keep only vertices adjacent from every vertex in prefix + [v].
            next_candidates = list(candidates.intersection(G.neighbors(v)))
        find_simplex_rec(next_candidates, prefix + [v], all_simplices, max_dim, G)

    return all_simplices if is_head else None


def compute_lifting(G, max_dim=2, num_threads=4):
    """Enumerate directed flag-complex simplices of ``G`` in parallel.

    The node list is split into ``num_threads`` contiguous chunks. Each chunk
    seeds one ``find_simplex_rec`` search, and the searches are distributed
    over a process pool with ``num_threads`` workers.

    Args:
        G: loop-free graph exposing ``nodes`` and ``neighbors``.
        max_dim: highest simplex dimension; a ``d``-simplex has ``d + 1``
            vertices.
        num_threads: number of chunks and worker processes (values below 1
            are treated as 1).

    Returns:
        List of ``max_dim + 1`` lists; entry ``d`` holds every ``d``-simplex
        as an ordered vertex list.
    """
    num_threads = max(1, num_threads)  # avoid ZeroDivisionError / empty pool
    nodes = list(G.nodes)  # materialise once instead of once per chunk
    nodes_per_thread = len(nodes) // num_threads + 1
    tasks = [
        (
            nodes[i * nodes_per_thread : (i + 1) * nodes_per_thread],
            [],
            [],
            max_dim + 1,
            G,
        )
        for i in range(num_threads)
    ]

    start_time = time.perf_counter()
    # Size the pool by num_threads: a bare Pool() uses os.cpu_count() workers,
    # ignoring the argument and possibly spawning many idle processes.
    with Pool(processes=num_threads) as pool:
        results = pool.starmap(find_simplex_rec, tasks)
    print("DFC Lifting Time:", time.perf_counter() - start_time)

    batch_simplices = [[] for _ in range(max_dim + 1)]
    for chunk_result in results:
        for dim, simplices in enumerate(chunk_result):
            batch_simplices[dim].extend(simplices)

    return batch_simplices


def compute_and_save_lifting(
    max_dim=2, num_threads=8, dset="cora_ml", path="./", flagser=False
):
    """Build the digraph of ``dset``, lift it, and persist the result with ``torch.save``.

    Args:
        max_dim: highest simplex dimension, forwarded to ``compute_lifting``.
        num_threads: forwarded to ``compute_lifting``.
        dset: dataset name.
        path: root directory used both for loading the dataset and for the
            saved ``<folder>/<folder>_all_simplices`` file.
        flagser: currently unused by this function.

    Returns:
        The per-dimension simplex lists that were saved.
    """
    print(f"*** Computing the lifting up to order {max_dim} ***")
    graph = compute_digraph_from_dataset(dset, path)
    lifted = compute_lifting(graph, max_dim, num_threads)

    folder = get_folder_from_dset(dset)
    out_file = f"{path}/{folder}/{folder}_all_simplices"
    torch.save(lifted, out_file)

    return lifted


if __name__ == "__main__":
    # No command-line behaviour is defined; import this module and call
    # compute_and_save_lifting() instead.
    ...
