# from collections import deque
import collections
import copy
import gc
import heapq
import math
import os
import random
import sys
import time
from glob import glob
from typing import Any, Dict, List, Tuple

import numpy as np
import torch
import torch.multiprocessing as mp

from matplotlib import pyplot as plt
from tqdm import tqdm

from .aig import _Learned_Node, Learned_AIG

# Raise CPython's default recursion limit (1000).
# NOTE(review): presumably needed by the recursive DFS search and/or by
# copy.deepcopy of linked AIG node structures — confirm which one needs it.
sys.setrecursionlimit(1500)

# Name prefix indexed directly by an edge type: negation[1] == "" for a plain
# edge and negation[-1] == "-" (Python negative indexing) for an inverted edge.
# Index 0 is not used by the callers visible in this file.
negation = ["", "", "-"]


def report_parallel_results(queue, report_time: int = 5):
    """Print the size of *queue* every *report_time* seconds until it looks empty.

    Each report shows the current queue size, the size seen at the previous
    report and the seconds elapsed since this function was entered. The
    function returns as soon as ``queue.empty()`` is observed to be true.
    """
    started_at = time.time()
    current_size = 0
    while not queue.empty():
        previous_size, current_size = current_size, queue.qsize()
        elapsed = round(time.time() - started_at, 2)
        print(
            f"\rCurrent queue size: {current_size}. Old queue size: {previous_size}. Elapsed time: {elapsed}",
            flush=True,
        )
        time.sleep(report_time)


def generate_optimal_aigs(
    num_inputs: int, extra_nodes: int, strategy: str = "BFS"
) -> None:
    """Enumerate AIGs over *num_inputs* inputs with up to *extra_nodes* AND nodes.

    ``strategy == "DFS"`` selects the recursive depth-first search; any other
    value selects the level-by-level breadth-first search.
    """
    if strategy != "DFS":
        generate_optimal_aigs_BFS(num_inputs, extra_nodes)
        return
    # The DFS driver takes an absolute node limit rather than a count of extra
    # nodes. NOTE(review): this assumes the initial AIG holds num_inputs + 1
    # nodes — confirm against Learned_AIG.
    generate_optimal_aigs_DFS(num_inputs, extra_nodes + num_inputs + 1)


def array_parallel_optimal_aigs_BFS(num_inputs: int, extra_nodes: int) -> None:
    """Experimental multi-process BFS that keeps truth tables as integer arrays.

    The PI-only AIG is expanded in this process. Worker processes
    (``new_parallel_BFS_helper``) then keep expanding the queued AIGs, sharing
    ``memory`` and ``results`` through a ``multiprocessing`` manager. The
    collected results are written with ``write_aigs``.

    Args:
        num_inputs: Number of primary inputs.
        extra_nodes: Number of AND nodes that may be added to the initial AIG.
    """
    init_aig = Learned_AIG(
        n_pis=num_inputs, n_pos=1, truth_tables=[torch.rand(2**num_inputs)]
    )
    n_init_nodes = len(init_aig._nodes)
    aig_size_limit = n_init_nodes + extra_nodes
    # array2str already returns the packed int; the former int(<int>, 2)
    # raised TypeError. uint64 (rather than uint32) fits the 2**6-bit truth
    # tables of 6-input functions.
    init_graph_tts = np.array(
        [
            array2str(init_aig[i].truth_table.numpy().astype(int))
            for i in range(n_init_nodes)
        ],
        dtype=np.uint64,
    )

    with mp.Manager() as manager:
        memory = manager.list()
        results = manager.dict()
        queue = mp.Queue()

        # Slot [0] stays empty: new_try_node indexes memory by n_ands() of the
        # grown AIG (presumably >= 1 once a node has been added).
        for _ in range(extra_nodes + 1):
            memory.append(manager.list())

        # os.cpu_count() may return None, and // 4 is 0 on machines with fewer
        # than 4 CPUs, which would start no workers at all.
        n_processes = max(1, (os.cpu_count() or 1) // 4)
        print("Number of processes:", n_processes)
        for left_id in range(1, n_init_nodes):
            for right_id in range(left_id + 1, n_init_nodes):
                for left_edge_type in [1, -1]:
                    for right_edge_type in [1, -1]:
                        new_aig, new_graphs_tts = new_try_node(
                            init_aig,
                            left_id,
                            right_id,
                            left_edge_type,
                            right_edge_type,
                            init_graph_tts,
                            {},
                            memory,
                            results,
                        )
                        # Only enqueue AIGs that still have room for another node.
                        if (
                            new_aig is not None
                            and len(new_aig._nodes) < aig_size_limit
                        ):
                            queue.put((new_aig, new_graphs_tts))

        processes = []
        for _ in range(n_processes):
            p = mp.Process(
                target=new_parallel_BFS_helper,
                args=[queue, init_graph_tts, memory, results, aig_size_limit],
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # Materialise the manager proxy into a plain dict. The old code stored
        # the copy in a misspelled variable and passed the proxy on instead.
        write_aigs(num_inputs, dict(results))


def parallel_optimal_aigs_BFS(num_inputs: int, extra_nodes: int) -> None:
    """Multi-process breadth-first search for minimum-size AIGs.

    The PI-only AIG is expanded once in this process. The resulting candidates
    go onto a manager queue that worker processes (``parallel_BFS_helper``)
    keep expanding. Each worker sends back its local results, which are
    merged with ``merge_dicts`` and written with ``write_aigs``. A helper
    process prints the queue size while the search runs.

    Args:
        num_inputs: Number of primary inputs of the AIGs to enumerate.
        extra_nodes: How many AND nodes may be added on top of the initial AIG.
    """
    init_graph_tts = {}
    init_aig = Learned_AIG(
        n_pis=num_inputs, n_pos=1, truth_tables=[torch.rand(2**num_inputs)]
    )
    aig_size_limit = len(init_aig._nodes) + extra_nodes

    for i in range(len(init_aig._nodes)):
        init_graph_tts[array2str(init_aig[i].truth_table.numpy().astype(int))] = 1

    with mp.Manager() as manager:
        queue = manager.Queue()
        result_queue = manager.Queue()
        # One shared dict of graph hashes per AIG size; try_node2 indexes it
        # with len(new_aig._nodes), which never exceeds aig_size_limit.
        memory = [manager.dict() for _ in range(aig_size_limit + 1)]

        # os.cpu_count() may return None, and // 2 is 0 on a single-CPU
        # machine, which would start no workers and silently skip the search.
        n_processes = max(1, (os.cpu_count() or 1) // 2)

        init_result = {}
        print("Number of processes:", n_processes)
        n_init_nodes = len(init_aig._nodes)
        for left_id in range(1, n_init_nodes):
            for right_id in range(left_id + 1, n_init_nodes):
                for left_edge_type in [1, -1]:
                    for right_edge_type in [1, -1]:
                        new_aig, new_graphs_tts = try_node2(
                            init_aig,
                            left_id,
                            right_id,
                            left_edge_type,
                            right_edge_type,
                            init_graph_tts,
                            {},
                            memory,
                            init_result,
                        )

                        # Only enqueue AIGs that still have room for another node.
                        if (
                            new_aig is not None
                            and len(new_aig._nodes) < aig_size_limit
                        ):
                            queue.put((new_aig, new_graphs_tts))

        if init_result:
            result_queue.put(init_result)
        search_start = time.time()

        processes = []
        for _ in range(n_processes):
            p = mp.Process(
                target=parallel_BFS_helper,
                args=[queue, init_graph_tts, memory, result_queue, aig_size_limit],
            )
            p.start()
            processes.append(p)

        report_p = mp.Process(target=report_parallel_results, args=[queue])
        report_p.start()

        for p in processes:
            p.join()
        print("Search elapsed time:", time.time() - search_start)
        report_p.terminate()
        # Reap the terminated reporter so it does not linger as a zombie.
        report_p.join()

        result_list = []
        while not result_queue.empty():
            result_list.append(result_queue.get())
        results = merge_dicts(result_list)

        write_aigs(num_inputs, results)
        for p in processes:
            p.close()


def merge_dicts(list_dicts: List[Dict[Any, Any]]) -> Dict[Any, List[Any]]:
    """Concatenate the per-key lists of several dicts into one new dict.

    Keys keep first-seen order. Values for a repeated key are concatenated in
    the order the dicts appear. Neither the input dicts nor their lists are
    mutated.
    """
    merged: Dict[Any, List[Any]] = {}
    for partial in list_dicts:
        for key, values in partial.items():
            merged.setdefault(key, []).extend(values)
    return merged


def new_parallel_BFS_helper(
    queue: mp.Queue,
    init_graph_tts: np.ndarray,
    memory: Dict[int, List[Dict[str, bool]]],
    results: Dict[str, List[Dict[str, Any]]],
    aig_size_limit: int,
    poll_timeout: float = 1.0,
) -> None:
    """Worker loop for ``array_parallel_optimal_aigs_BFS``.

    Pops ``(aig, graph_tts)`` pairs from *queue* and tries every
    (left, right, polarity) AND candidate through ``new_try_node``. Accepted
    AIGs that are still smaller than *aig_size_limit* nodes go back on the
    queue. The worker stops once no item arrives within *poll_timeout*
    seconds.

    The previous ``while not queue.empty(): queue.get()`` loop could
    deadlock: when two workers both saw one remaining item, the loser blocked
    in ``get()`` forever and the parent's ``join()`` hung.
    """
    from queue import Empty  # the ``queue`` parameter shadows the module name

    torch.set_num_threads(2)
    while True:
        try:
            aig, graph_tts = queue.get(timeout=poll_timeout)
        except Empty:
            break
        n_nodes = len(aig._nodes)
        for left_id in range(1, n_nodes):
            for right_id in range(left_id + 1, n_nodes):
                for left_edge_type in [1, -1]:
                    for right_edge_type in [1, -1]:
                        new_aig, new_graphs_tts = new_try_node(
                            aig,
                            left_id,
                            right_id,
                            left_edge_type,
                            right_edge_type,
                            init_graph_tts,
                            graph_tts,
                            memory,
                            results,
                        )

                        # Only enqueue AIGs that still have room for another node.
                        if (
                            new_aig is not None
                            and len(new_aig._nodes) < aig_size_limit
                        ):
                            queue.put((new_aig, new_graphs_tts))


def parallel_BFS_helper(
    queue: mp.Queue,
    init_graph_tts: Dict[str, bool],
    memory: List[Dict[str, bool]],
    results_queue: mp.Queue,
    aig_size_limit: int,
    poll_timeout: float = 1.0,
) -> None:
    """Worker loop for ``parallel_optimal_aigs_BFS``.

    Pops ``(aig, graph_tts)`` pairs from *queue* and tries every
    (left, right, polarity) AND candidate through ``try_node2``. Accepted
    AIGs still smaller than *aig_size_limit* nodes go back on the queue.
    Results accumulate in a process-local dict, which is put on
    *results_queue* when the worker stops. A worker stops once no item
    arrives within *poll_timeout* seconds.

    The previous ``while not queue.empty(): queue.get()`` loop could
    deadlock: when two workers both saw one remaining item, the loser blocked
    in ``get()`` forever and the parent's ``join()`` hung.
    """
    from queue import Empty  # the ``queue`` parameter shadows the module name

    results = {}
    while True:
        try:
            aig, graph_tts = queue.get(timeout=poll_timeout)
        except Empty:
            break
        n_nodes = len(aig._nodes)
        for left_id in range(1, n_nodes):
            for right_id in range(left_id + 1, n_nodes):
                for left_edge_type in [1, -1]:
                    for right_edge_type in [1, -1]:
                        new_aig, new_graphs_tts = try_node2(
                            aig,
                            left_id,
                            right_id,
                            left_edge_type,
                            right_edge_type,
                            init_graph_tts,
                            graph_tts,
                            memory,
                            results,
                        )

                        # Only enqueue AIGs that still have room for another node.
                        if (
                            new_aig is not None
                            and len(new_aig._nodes) < aig_size_limit
                        ):
                            queue.put((new_aig, new_graphs_tts))
    results_queue.put(results)


def generate_optimal_aigs_BFS(num_inputs: int, extra_nodes: int) -> None:
    """Single-process, level-by-level (breadth-first) AIG enumeration.

    Level ``k`` holds every accepted AIG with ``k`` extra AND nodes. Each is
    expanded with every (left, right, polarity) candidate via ``try_node``.
    Children are queued only when another level will still be expanded. The
    results gathered by ``try_node`` are written with ``write_aigs``.

    Args:
        num_inputs: Number of primary inputs.
        extra_nodes: Number of AND nodes to add on top of the initial AIG.
    """
    # Only the ``collections`` module is imported at file level (the direct
    # ``deque`` import is commented out), so a bare ``deque()`` raised NameError.
    processes = collections.deque()
    memory = {}
    results = {}
    init_graph_tts = {}
    init_aig = Learned_AIG(
        n_pis=num_inputs, n_pos=1, truth_tables=[torch.rand(2**num_inputs)]
    )

    for i in range(len(init_aig._nodes)):
        init_graph_tts[array2str(init_aig[i].truth_table.numpy().astype(int))] = 1

    for i in range(1, extra_nodes + 1):
        memory[len(init_aig._nodes) + i] = []

    processes.append((init_aig, {}))
    for k in range(extra_nodes):
        # Everything still queued belongs to level k: children are appended
        # behind it, and only when a further level will be expanded.
        cur_level_processes = len(processes)
        print(
            "Added "
            + str(k)
            + " extra nodes so far. Will consider "
            + str(cur_level_processes)
            + " at this stage"
        )
        is_last_level = k + 1 == extra_nodes
        for _ in range(cur_level_processes):
            aig, graph_tts = processes.popleft()
            n_nodes = len(aig._nodes)
            for left_id in range(1, n_nodes):
                for right_id in range(left_id + 1, n_nodes):
                    for left_edge_type in [1, -1]:
                        for right_edge_type in [1, -1]:
                            new_aig, new_graphs_tts = try_node(
                                aig,
                                left_id,
                                right_id,
                                left_edge_type,
                                right_edge_type,
                                init_graph_tts,
                                graph_tts,
                                memory,
                                results,
                            )

                            # Skip queueing children that would never be expanded.
                            if new_aig is not None and not is_last_level:
                                processes.append((new_aig, new_graphs_tts))

    write_aigs(num_inputs, results)


def new_try_node(
    aig: Learned_AIG,
    left_id: int,
    right_id: int,
    left_edge_type: int,
    right_edge_type: int,
    init_graph_tts: Dict[str, bool],
    graph_tts: Dict[str, bool],
    memory: Dict[int, List[Dict[str, bool]]],
    results: Dict[str, List[dict[str, Any]]],
) -> Tuple[Learned_AIG | None, Dict[str, bool] | None]:
    """Array-memory variant of ``try_node``: try adding AND(left, right) to *aig*.

    Asks ``new_action`` whether the candidate is acceptable. If so, returns
    a deep copy of *aig* with the new named node and a copy of *graph_tts*
    extended with the new truth-table hash. It also appends that copy to
    ``memory[new_aig.n_ands()]`` and records the result via ``store_results``
    when ``new_action`` asks for it. Returns ``(None, None)`` when the
    candidate is rejected.

    NOTE(review): *graph_tts* is extended as ``{position: hash}`` (key =
    ``n_ands() - 1`` before the node is added), while ``new_action`` tests
    ``hash in graph_tts``, which checks dict keys — confirm this is intended.
    """

    left = aig._nodes[left_id]
    right = aig._nodes[right_id]

    new_tt_hash, inv_new_tt_hash, new_name, store = new_action(
        aig,
        left,
        right,
        left_edge_type,
        right_edge_type,
        init_graph_tts,
        graph_tts,
        memory,
    )

    if new_tt_hash == None:
        return (None, None)
    # It's a new graph!
    copy_aig = copy.deepcopy(aig)
    copy_graph_tts = copy.deepcopy(graph_tts)
    # Key is the AND count *before* create_and below, i.e. the new node's
    # position among AND nodes (0-based) — order with create_and matters.
    copy_graph_tts[copy_aig.n_ands() - 1] = new_tt_hash
    # NOTE(review): left/right belong to the original aig, not copy_aig —
    # confirm create_and is meant to accept nodes from another AIG.
    n = copy_aig.create_and(left, right, left_edge_type, right_edge_type)
    copy_aig.set_name(n.node_id, new_name)

    # Add the new graph to memory (indexed by the AND count after the insert)
    memory[copy_aig.n_ands()].append(copy_graph_tts)

    # Add the solution to the results
    if store:
        store_results(copy_aig, n.node_id, new_tt_hash, inv_new_tt_hash, results)

    return copy_aig, copy_graph_tts


def try_node(
    aig: Learned_AIG,
    left_id: int,
    right_id: int,
    left_edge_type: int,
    right_edge_type: int,
    init_graph_tts: dict[str, 1],
    graph_tts: dict[str, 1],
    memory: dict[int, list[dict[str, 1]]],
    results: dict[str, list[dict[str, any]]],
) -> tuple[Learned_AIG | None, dict[str, 1] | None]:
    """Try adding AND(left, right) with the given edge polarities to *aig*.

    ``action`` decides whether the candidate is acceptable. On acceptance:

    * a deep copy of *aig* receives the new node, named by ``action``;
    * a copy of *graph_tts* gains the new truth-table hash;
    * that copy is appended to ``memory[len(new_aig._nodes)]``;
    * when requested, the node is recorded through ``store_results``.

    Returns ``(new_aig, new_graph_tts)``, or ``(None, None)`` on rejection.
    """
    left, right = aig._nodes[left_id], aig._nodes[right_id]

    verdict = action(
        aig,
        left,
        right,
        left_edge_type,
        right_edge_type,
        init_graph_tts,
        graph_tts,
        memory,
    )
    tt_hash, inv_tt_hash, node_name, should_store = verdict
    if tt_hash is None:
        return (None, None)

    child_tts = copy.deepcopy(graph_tts)
    child_tts[tt_hash] = 1
    child_aig = copy.deepcopy(aig)
    # NOTE(review): left/right come from the original aig, not child_aig.
    node = child_aig.create_and(left, right, left_edge_type, right_edge_type)
    child_aig.set_name(node.node_id, node_name)

    memory[len(child_aig._nodes)].append(child_tts)

    if should_store:
        store_results(child_aig, node.node_id, tt_hash, inv_tt_hash, results)

    return child_aig, child_tts


def try_node2(
    aig: Learned_AIG,
    left_id: int,
    right_id: int,
    left_edge_type: int,
    right_edge_type: int,
    init_graph_tts: Dict[str, bool],
    graph_tts: Dict[str, bool],
    memory: List[Dict[str, bool]],
    results: Dict[str, List[Dict[str, Any]]],
) -> tuple[Learned_AIG | None, Dict[str, bool] | None]:
    """Try adding AND(left, right) to *aig* using graph-hash memory (``action2``).

    On acceptance:

    * a deep copy of *aig* receives the new node, named by ``action2``;
    * a copy of *graph_tts* gains the new truth-table hash;
    * the new graph hash is recorded in ``memory[len(new_aig._nodes)]``;
    * when requested, each polarity whose graph hash was not already known is
      offered to ``store_results2``.

    Returns ``(new_aig, new_graph_tts)``, or ``(None, None)`` on rejection.
    """
    left, right = aig._nodes[left_id], aig._nodes[right_id]

    verdict = action2(
        aig,
        left,
        right,
        left_edge_type,
        right_edge_type,
        init_graph_tts,
        graph_tts,
        memory,
    )
    graph_hash, inv_graph_hash, tt_hash, inv_tt_hash, node_name, should_store = verdict
    if tt_hash is None:
        return (None, None)

    child_aig = copy.deepcopy(aig)
    child_tts = copy.deepcopy(graph_tts)
    child_tts[tt_hash] = True
    # NOTE(review): left/right come from the original aig, not child_aig.
    node = child_aig.create_and(left, right, left_edge_type, right_edge_type)
    child_aig.set_name(node.node_id, node_name)

    # NOTE(review): graph_hash is None when only the inverted variant was new,
    # in which case a None key is recorded and the inverted hash is not.
    memory[len(child_aig._nodes)][graph_hash] = True

    if should_store:
        candidates = (
            (graph_hash, tt_hash, 1),
            (inv_graph_hash, inv_tt_hash, -1),
        )
        for known_hash, tt, polarity in candidates:
            if known_hash is not None:
                store_results2(child_aig, node.node_id, tt, polarity, results)

    return child_aig, child_tts


def store_results(
    aig: Learned_AIG,
    node_id: int,
    new_tt_hash: str | int,
    inv_new_tt_hash: str | int,
    results: Dict[str | int, List[Dict[str, Any]]],
) -> None:
    if new_tt_hash not in results:
        results[new_tt_hash] = [{"aig": aig, "po": node_id, "edge_type": 1}]
    elif aig.n_ands() == results[new_tt_hash][0]["aig"].n_ands():
        results[new_tt_hash] += [{"aig": aig, "po": node_id, "edge_type": 1}]
    elif aig.n_ands() < results[new_tt_hash][0]["aig"].n_ands():
        results[new_tt_hash] = [{"aig": aig, "po": node_id, "edge_type": 1}]
        # print("2. key", new_tt_hash, "value", results[new_tt_hash])

    # if inv_new_tt_hash not in results:
    #     results[inv_new_tt_hash] = []
    if inv_new_tt_hash not in results:
        results[inv_new_tt_hash] = [{"aig": aig, "po": node_id, "edge_type": -1}]
    elif aig.n_ands() == results[new_tt_hash][0]["aig"].n_ands():
        results[inv_new_tt_hash] += [{"aig": aig, "po": node_id, "edge_type": -1}]
    elif aig.n_ands() < results[new_tt_hash][0]["aig"].n_ands():
        results[inv_new_tt_hash] = [{"aig": aig, "po": node_id, "edge_type": -1}]


def store_results2(
    aig: Learned_AIG,
    node_id: int,
    new_tt_hash: str | int,
    edge_type: int,
    results: Dict[str | int, List[Dict[str, Any]]],
) -> None:
    if new_tt_hash not in results:
        results[new_tt_hash] = [{"aig": aig, "po": node_id, "edge_type": edge_type}]
    else:
        best_aig_size = results[new_tt_hash][0]["aig"].n_ands()
        if aig.n_ands() == best_aig_size:
            results[new_tt_hash].append(
                {"aig": aig, "po": node_id, "edge_type": edge_type}
            )
            # print(results[new_tt_hash])
        elif aig.n_ands() < best_aig_size:
            results[new_tt_hash] = [{"aig": aig, "po": node_id, "edge_type": edge_type}]


def write_aigs(num_inputs: int, results: Dict[str, List[Dict[str, Any]]]) -> None:
    """Write every minimum-size AIG in *results* to ``<num_inputs>_inputs/<tt>/<i>.aig``.

    *results* maps a truth table (str or packed int) to entries
    ``{"aig", "po", "edge_type"}``. Only the smallest AIGs per truth table are
    kept (``clean_up_results``). Integer truth tables are rendered as
    ``2**num_inputs``-bit binary folder names. Before writing, each AIG's
    last output is connected to ``data["po"]`` with ``data["edge_type"]``.

    Files inside a folder are numbered 0, 1, 2, ... Previously the counter
    was only advanced outside the inner loop, so every entry overwrote
    ``0.aig``.
    """
    clean_results = clean_up_results(results)
    root_path = str(num_inputs) + "_inputs"
    os.makedirs(root_path, exist_ok=True)
    for tt, entries in clean_results.items():
        fold_name = f"{tt:0{2**num_inputs}b}" if isinstance(tt, int) else tt
        tt_path = os.path.join(root_path, fold_name)
        os.makedirs(tt_path, exist_ok=True)
        for i, data in enumerate(entries):
            aig = data["aig"]
            # The same AIG object may appear under both polarities; set its
            # output edge right before writing it.
            aig.set_po_edge(data["po"], -1, data["edge_type"])
            aig[-1].calculate_truth_table(force=True)
            aig.write_aig(os.path.join(tt_path, f"{i}.aig"))


def generate_optimal_aigs_DFS(num_inputs: int, limit_nodes: int) -> None:
    """Depth-first AIG enumeration up to *limit_nodes* total nodes.

    Seeds the known truth tables with those of the initial AIG's nodes and
    creates one empty memory list per reachable node count. It then runs
    ``generate_optimal_aigs_DFS_helper``, prints the raw results and writes
    them with ``write_aigs``.
    """
    aig = Learned_AIG(
        n_pis=num_inputs, n_pos=1, truth_tables=[torch.rand(2**num_inputs)]
    )
    init_graph_tts = {
        array2str(aig[idx].truth_table.numpy().astype(int)): 1
        for idx in range(len(aig._nodes))
    }

    base_size = len(aig._nodes)
    memory = {
        base_size + offset: [] for offset in range(1, limit_nodes - base_size + 1)
    }
    results = {}

    generate_optimal_aigs_DFS_helper(
        aig, init_graph_tts, {}, memory, limit_nodes, results
    )
    print(results)
    write_aigs(num_inputs, results)


def generate_optimal_aigs_DFS_helper(
    aig: Learned_AIG,
    init_graph_tts: Dict[str, bool],
    graph_tts: Dict[str, bool],
    memory: Dict[int, List[Dict[str, bool]]],
    limit_nodes: int,
    results: Dict[str, List[Dict[str, Any]]],
) -> None:
    """Recursively extend *aig* one AND node at a time up to *limit_nodes* nodes.

    Every (left, right, polarity) candidate goes through ``try_node``. That
    call records accepted graphs in *memory* and eligible ones in *results*.
    The search then recurses on each accepted child that is still below the
    limit.

    The previous inline version executed ``del graph_tts[new_tt_hash]``.
    ``action`` rejects any hash already in *graph_tts*, so that line always
    raised KeyError, and the child's table never received the new hash.
    Delegating to ``try_node`` applies the intended
    ``new_graph_tts[new_tt_hash] = 1``.
    """
    if len(aig._nodes) == limit_nodes:
        return None

    # aig itself is never mutated below (only deep copies are), so its size is fixed.
    n_nodes = len(aig._nodes)
    for left_id in range(1, n_nodes):
        for right_id in range(left_id + 1, n_nodes):
            for left_edge_type in [1, -1]:
                for right_edge_type in [1, -1]:
                    new_aig, new_graph_tts = try_node(
                        aig,
                        left_id,
                        right_id,
                        left_edge_type,
                        right_edge_type,
                        init_graph_tts,
                        graph_tts,
                        memory,
                        results,
                    )
                    if new_aig is None:
                        continue

                    # Recur
                    if len(new_aig._nodes) != limit_nodes:
                        generate_optimal_aigs_DFS_helper(
                            new_aig,
                            init_graph_tts,
                            new_graph_tts,
                            memory,
                            limit_nodes,
                            results,
                        )
    return None


def is_parent(parent_node: "_Learned_Node", child_node: "_Learned_Node"):
    """Return True when *parent_node* is a direct fanin of *child_node*.

    Primary inputs have no fanins, so they never have a parent.
    """
    if child_node.is_pi():
        return False
    parent_id = parent_node.node_id
    return bool(
        child_node.left.node_id == parent_id or child_node.right.node_id == parent_id
    )


def new_action(
    aig: Learned_AIG,
    left: _Learned_Node,
    right: _Learned_Node,
    left_edge_type: int,
    right_edge_type: int,
    init_graph_tts: np.ndarray,
    graph_tts: np.ndarray,
    memory: Dict[int, List[Dict[str, bool]]],
) -> tuple[int | None, int | None, str | None, bool | None]:
    """Array-memory variant of ``action``: vet the candidate AND(left, right).

    Returns ``(new_tt_hash, inv_new_tt_hash, new_name, store)`` for an
    acceptable candidate. The hashes are the packed truth tables of the node
    and its complement. ``store`` is True when the name mentions every PI
    index. Returns four ``None`` values when the candidate is rejected.
    """
    rejected = (None, None, None, None)

    # Skip pairs where one node directly feeds the other.
    if is_parent(left, right) or is_parent(right, left):
        return rejected

    # Truth table of the potential new node. array2str already returns the
    # packed int; the former int(array2str(...), 2) raised TypeError
    # ("int() can't convert non-string with explicit base").
    new_tt = get_new_numpy_tt(left, right, left_edge_type, right_edge_type)
    new_tt_hash = array2str(new_tt.astype(int))
    inv_new_tt_hash = array2str((~new_tt).astype(int))

    # Reject nodes equivalent (in either polarity) to an initial node.
    if new_tt_hash in init_graph_tts or inv_new_tt_hash in init_graph_tts:
        return rejected

    # Skip if the potential new node already exists in the graph.
    # NOTE(review): new_try_node builds graph_tts as {position: hash}; for a
    # dict, ``in`` tests the keys (positions), not the hashes — confirm.
    if new_tt_hash in graph_tts or inv_new_tt_hash in graph_tts:
        return rejected

    # Name of the new node
    new_name = (
        negation[left_edge_type]
        + aig.get_name(left)
        + negation[right_edge_type]
        + aig.get_name(right)
    )

    # Avoid reconvergent paths with PIs
    if left.is_pi() and aig.get_name(left) in aig.get_name(right):
        return rejected

    if right.is_pi() and aig.get_name(right) in aig.get_name(left):
        return rejected

    # Skip if the graph with the new node or the new inverted node already exists
    if len(memory[aig.n_ands() + 1]) != 0:
        mem_copy = np.vstack(memory[aig.n_ands() + 1])[:, : aig.n_ands() + 1]
        if new_graph_exists(mem_copy, graph_tts, new_tt_hash, inv_new_tt_hash):
            return rejected

    # Store only nodes whose name mentions every PI index 1..n_pis.
    store = all(str(i) in new_name for i in range(1, aig.n_pis() + 1))

    return (new_tt_hash, inv_new_tt_hash, new_name, store)


def action(
    aig: Learned_AIG,
    left: _Learned_Node,
    right: _Learned_Node,
    left_edge_type: int,
    right_edge_type: int,
    init_graph_tts: Dict[str, bool],
    graph_tts: Dict[str, bool],
    memory: Dict[int, List[Dict[str, bool]]],
) -> Tuple[int | None, int | None, str | None, bool | None]:
    """Vet the candidate AND(left, right) for the dict-memory searches.

    Returns ``(new_tt_hash, inv_new_tt_hash, new_name, store)`` for an
    acceptable candidate, or four ``None`` values when it is rejected. The
    hashes are the packed truth tables of the node and its complement.
    ``store`` is True when the name mentions every PI index.

    Rejected candidates are:

    * pairs where one node feeds the other;
    * nodes equivalent to an initial node or to a node already in the graph;
    * PI reconvergence;
    * graphs already present in ``memory[len(aig._nodes) + 1]``.

    Changes from the previous version:

    * ``graph_exists`` only reads *memory*, so the two per-candidate deep
      copies of a whole memory level were dropped.
    * The ``graph_hash_exists`` call was removed. It returns a 2-tuple, so the
      ``is None`` test never fired, and its string hashes could never match
      the dict entries stored in this memory.
    """
    rejected = (None, None, None, None)

    # Skip pairs where one node directly feeds the other.
    if is_parent(left, right) or is_parent(right, left):
        return rejected

    # Truth table of the potential new node
    new_tt = get_new_numpy_tt(left, right, left_edge_type, right_edge_type)
    new_tt_hash = array2str(new_tt.astype(int))
    inv_new_tt_hash = array2str((~new_tt).astype(int))

    # Reject nodes equivalent (in either polarity) to an initial node.
    if new_tt_hash in init_graph_tts or inv_new_tt_hash in init_graph_tts:
        return rejected

    # Skip if the potential new node already exists in the graph
    if new_tt_hash in graph_tts or inv_new_tt_hash in graph_tts:
        return rejected

    # Name of the new node
    new_name = (
        negation[left_edge_type]
        + aig.get_name(left)
        + negation[right_edge_type]
        + aig.get_name(right)
    )

    # Avoid reconvergent paths with PIs
    if left.is_pi() and aig.get_name(left) in aig.get_name(right):
        return rejected

    if right.is_pi() and aig.get_name(right) in aig.get_name(left):
        return rejected

    # Skip if the graph with the new node or the new inverted node already exists
    if graph_exists(
        memory[len(aig._nodes) + 1], graph_tts, new_tt_hash, inv_new_tt_hash
    ):
        return rejected

    # Store only nodes whose name mentions every PI index 1..n_pis.
    store = all(str(i) in new_name for i in range(1, aig.n_pis() + 1))

    return (new_tt_hash, inv_new_tt_hash, new_name, store)


def action2(
    aig: Learned_AIG,
    left: _Learned_Node,
    right: _Learned_Node,
    left_edge_type: int,
    right_edge_type: int,
    init_graph_tts: Dict[str, bool],
    graph_tts: Dict[str, bool],
    memory: Dict[int, List[Dict[str, bool]]],
) -> Tuple[str | None, str | None, str | None, str | None, str | None, bool | None]:
    """Vet the candidate AND(left, right) against graph-hash memory.

    Returns ``(new_graph_hash, inv_new_graph_hash, new_tt_hash,
    inv_new_tt_hash, new_name, store)``. A graph hash is None when that
    polarity's graph is already in ``memory[len(aig._nodes) + 1]``. When the
    candidate is rejected outright, all six values are None. Rejection
    happens for:

    * pairs where one node feeds the other;
    * nodes equivalent to an initial or existing node;
    * PI reconvergence;
    * candidates where both graph hashes are already known.
    """
    rejected = (None,) * 6

    # A node cannot be combined with its own direct fanin/fanout.
    if is_parent(left, right) or is_parent(right, left):
        return rejected

    candidate_tt = get_new_numpy_tt(left, right, left_edge_type, right_edge_type)
    tt_hash = array2str(candidate_tt.astype(int))
    inv_tt_hash = array2str((~candidate_tt).astype(int))

    # Reject functions already provided by an initial node or by this graph,
    # in either polarity.
    for known_tts in (init_graph_tts, graph_tts):
        if tt_hash in known_tts or inv_tt_hash in known_tts:
            return rejected

    left_name = aig.get_name(left)
    right_name = aig.get_name(right)
    node_name = (
        negation[left_edge_type] + left_name + negation[right_edge_type] + right_name
    )

    # Avoid reconvergent paths with PIs
    if left.is_pi() and left_name in right_name:
        return rejected
    if right.is_pi() and right_name in left_name:
        return rejected

    graph_hash, inv_graph_hash = graph_hash_exists(
        memory[len(aig._nodes) + 1],
        graph_tts,
        aig.n_pis(),
        tt_hash,
        inv_tt_hash,
    )
    if graph_hash is None and inv_graph_hash is None:
        return rejected

    # Store only nodes whose name mentions every PI index 1..n_pis.
    should_store = all(str(pi) in node_name for pi in range(1, aig.n_pis() + 1))

    return (
        graph_hash,
        inv_graph_hash,
        tt_hash,
        inv_tt_hash,
        node_name,
        should_store,
    )


def array2str(array: np.ndarray) -> int:
    """Pack a 0/1 (or boolean) array into an int, first element = most significant bit.

    The previous implementation parsed ``str(array)``. NumPy line-wraps that
    representation past 75 characters (e.g. 64-element truth tables of
    6-input functions), inserting newlines that broke ``int(..., 2)``. It also
    rendered bool arrays as ``True``/``False``. Joining the elements directly
    avoids both problems.

    Raises:
        ValueError: if *array* is empty or contains values other than 0/1.
    """
    bits = np.asarray(array).astype(int).ravel().tolist()
    return int("".join(map(str, bits)), 2)


def clean_up_results(
    results: dict[str, list[dict[str, Any]]]
) -> dict[str, list[dict[str, Any]]]:
    """Keep, per truth table, only the entries whose AIG has the fewest nodes.

    Size is measured as ``len(entry["aig"]._nodes)``. Ties are all kept in
    their original order, and empty lists stay empty. A new dict is returned.
    """
    cleaned = {}
    for tt, entries in results.items():
        sizes = [len(entry["aig"]._nodes) for entry in entries]
        smallest = min(sizes, default=float("inf"))
        cleaned[tt] = [
            entry for entry, size in zip(entries, sizes) if size == smallest
        ]
    return cleaned


def get_new_numpy_tt(
    left: "_Learned_Node", right: "_Learned_Node", left_edge_type: int, right_edge_type: int
) -> np.ndarray:
    """Return, as a NumPy array, the truth table of AND(left, right).

    An operand is complemented (``~``) when its edge type is -1, provided the
    other edge type is 1 or -1. Any other combination falls back to the plain
    AND of both truth tables.
    """
    polarity = (left_edge_type, right_edge_type)
    lhs = left.truth_table
    rhs = right.truth_table
    if polarity in ((-1, -1), (-1, 1)):
        lhs = ~lhs
    if polarity in ((-1, -1), (1, -1)):
        rhs = ~rhs
    return (lhs & rhs).numpy()


def issubset(sub_array: np.ndarray, sup_array: np.ndarray):
    """Return True when every element of *sub_array* also occurs in *sup_array*.

    ``np.intersect1d`` de-duplicates its output, so a *sub_array* with repeated
    values is reported as not a subset.
    """
    shared = np.intersect1d(sub_array, sup_array)
    return len(shared) == sub_array.size


def new_graph_exists(
    memory: np.ndarray, graph: np.ndarray, new_tt_hash: int, inv_new_tt_hash: int
) -> bool:
    """Return True if a stored graph already contains the candidate and all of *graph*.

    Each row of *memory* is treated as the truth-table hashes of one
    previously generated graph. A row matches when it contains *new_tt_hash*
    or *inv_new_tt_hash* and ``issubset(graph, row)`` holds.

    The unreachable code that followed the final ``return False`` was removed.
    """
    for existing_graph in memory:
        has_candidate = (
            new_tt_hash in existing_graph or inv_new_tt_hash in existing_graph
        )
        if has_candidate and issubset(graph, existing_graph):
            return True
    return False


def graph_exists(
    memory: List[Dict[str, bool]],
    graph: Dict[str, bool],
    new_tt_hash: str,
    inv_new_tt_hash: str,
) -> bool:
    """Return True if some stored graph contains the candidate and all of *graph*.

    Graphs are dicts whose keys are truth-table hashes. A stored graph matches
    when its keys include *new_tt_hash* or *inv_new_tt_hash* and are a
    superset of *graph*'s keys.
    """
    wanted = set(graph)
    return any(
        (new_tt_hash in stored or inv_new_tt_hash in stored) and wanted <= set(stored)
        for stored in memory
    )


def hash_graph(num_inputs, graph, new_tt):
    """Return an order-independent string key for *graph* plus the node *new_tt*.

    The truth-table ints (the keys/elements of *graph* plus *new_tt*) are
    sorted and concatenated as zero-padded hex. Each is padded to
    ``2 ** (num_inputs - 2)`` digits, i.e. ``2 ** num_inputs`` bits.
    """
    width = 2 ** (num_inputs - 2)
    members = sorted([*graph, new_tt])
    return "".join(format(tt, f"0{width}x") for tt in members)


def graph_hash_exists(
    memory: Dict[str, int],
    graph: Dict[int, bool],
    num_inputs: int,
    new_tt_hash: int,
    inv_new_tt_hash: int,
) -> tuple[str | None, str | None]:
    """Hash *graph* extended with each polarity of the new node.

    Returns ``(new_graph_hash, inv_new_graph_hash)``. An entry is replaced by
    None when that hash is already present in *memory*.
    """
    hashes = [
        hash_graph(num_inputs, graph, tt) for tt in (new_tt_hash, inv_new_tt_hash)
    ]
    unseen = [None if h in memory else h for h in hashes]
    return unseen[0], unseen[1]


def create_cut_data(
    aig: str | Learned_AIG,
    num_inputs: int,
    path_prefix: str = ".",
    start: int = 0,
    end: int | None = None,
) -> None:
    """Write one *num_inputs*-input cut per eligible node of *aig*.

    *aig* is either a path (read with ``Learned_AIG.read_aig``) or an
    already-loaded AIG. Nodes in ``aig._nodes[start:end]`` that are not PIs
    and have ``level >= log2(num_inputs)`` are passed to ``create_cut``. Every
    non-None cut is written to
    ``<path_prefix>/<num_inputs>_inputs/<fold>/<node_id>.aig``. ``<fold>`` is
    the file's base name without extension for a path, and ``"AIG"``
    otherwise.

    The folder name used to be derived with ``os.path.basename(aig)`` before
    checking the type, which raised TypeError for ``Learned_AIG`` objects.
    """
    if isinstance(aig, str):
        aig_name = aig
        fold_name = os.path.splitext(os.path.basename(aig))[0]
        aig = Learned_AIG.read_aig(aig)
    else:
        aig_name = "AIG"
        fold_name = aig_name

    out_dir = f"{path_prefix}/{num_inputs}_inputs/{fold_name}/"
    min_level = math.log2(num_inputs)
    text = f"{aig_name} {start}-{end}"
    for n in tqdm(aig._nodes[start:end], desc=text):
        if n.is_pi() or n.level < min_level:
            continue
        cut = create_cut(n, num_inputs)
        if cut is None:
            continue
        os.makedirs(out_dir, exist_ok=True)
        cut.write_aig(out_dir + str(n.node_id) + ".aig")
        # Free each cut eagerly; presumably cuts can be large.
        del cut
        gc.collect()


def create_multiple_cut_data(
    aig: str | Learned_AIG,
    num_inputs: int,
    path_prefix: str = ".",
    start: int = 0,
    end: int | None = None,
) -> None:
    """Write several *num_inputs*-input cuts per eligible node of *aig*.

    Every non-PI node in ``aig._nodes[start:end]`` whose level is at least
    ``log2(num_inputs)`` is passed to ``create_multiple_cuts``. The j-th
    non-None cut is written to
    ``{path_prefix}/{num_inputs}_inputs/{name}/{node_id}_{j}.aig``, where
    ``name`` is the input file's base name without extension, or ``"AIG"``
    when an in-memory graph is given.

    Args:
        aig: Path of an AIG file (read with ``Learned_AIG.read_aig``) or an
            already loaded ``Learned_AIG``.
        num_inputs: Number of inputs of every extracted cut.
        path_prefix: Root directory of the output tree.
        start: First node index to consider.
        end: One past the last node index, or None for "to the end".
    """
    if isinstance(aig, str):
        aig_name = aig
        fold_name = os.path.splitext(os.path.basename(aig))[0]
        aig = Learned_AIG.read_aig(aig)
    else:
        # os.path.basename() used to be applied before this type check, so
        # passing a Learned_AIG raised TypeError.
        aig_name = "AIG"
        fold_name = aig_name

    out_dir = f"{path_prefix}/{num_inputs}_inputs/{fold_name}/"
    text = aig_name + " " + str(start) + "-" + str(end)
    for n in tqdm(aig._nodes[start:end], desc=text):
        # Presumably a node at level L reaches at most 2**L leaves, so lower
        # nodes cannot yield a num_inputs-input cut.
        if n.is_pi() or n.level < math.log2(num_inputs):
            continue
        for j, cut in enumerate(create_multiple_cuts(n, num_inputs)):
            if cut is None:
                continue
            # Created lazily so no empty directory appears when nothing is found.
            os.makedirs(out_dir, exist_ok=True)
            cut.write_aig(out_dir + str(n.node_id) + "_" + str(j) + ".aig")
            del cut
            gc.collect()


def create_multiple_cuts(
    root_node: _Learned_Node, num_inputs: int
) -> List[Learned_AIG]:
    """Build several *num_inputs*-input cuts rooted at *root_node*.

    A random base cut with ``num_inputs - 1`` inputs is grown first with
    ``collect_cut_data``. Each frontier leaf of that base is then expanded, in
    turn, on an independent copy of the base state. Every expansion that ends
    with exactly ``num_inputs`` inputs (PIs plus frontier leaves) is turned
    into an AIG with ``construct_cut2``.

    Returns:
        The cuts found. The list is empty if the base cut could not be built.
    """
    # Generate a cut that has N-1 leaves.
    pis, leaves, visited_nodes = collect_cut_data(root_node, num_inputs - 1)
    cuts: List[Learned_AIG] = []

    # Ensure cut creation did not fail.
    if leaves is None:
        return cuts
    assert pis is not None
    assert visited_nodes is not None

    # Expand each leaf so that several N-leaf cuts can be created.
    for leaf in leaves:
        new_pis = list(pis)
        new_leaves = dict(leaves)
        # expand_node/zero_expansion append to the per-node parent lists, so
        # copy those lists as well. A shallow dict copy (copy.copy) shared them
        # between the base state and every trial expansion.
        new_visited_nodes = {
            node: list(parents) for node, parents in visited_nodes.items()
        }
        expand_node(leaf, new_visited_nodes, new_leaves, new_pis)

        if len(new_leaves) + len(new_pis) == num_inputs:
            cuts.append(
                construct_cut2(
                    root_node,
                    new_pis + list(new_leaves),
                    new_visited_nodes,
                    num_inputs,
                )
            )
    return cuts


def expand_node(
    expand_node: _Learned_Node,
    visited_nodes: Dict[_Learned_Node, List[_Learned_Node]],
    leaves: Dict[_Learned_Node, bool],
    pis: List[_Learned_Node],
) -> None:
    """Replace one frontier node by its two fan-ins, updating the cut in place.

    *expand_node* is removed from *leaves*. A fan-in seen for the first time
    gets an empty parent list in *visited_nodes* and is filed under *pis* if
    it is a PI, otherwise under *leaves*. In every case *expand_node* is
    appended to the fan-in's parent list. ``zero_expansion`` is run after each
    of the two fan-ins is handled.

    Raises:
        KeyError: If *expand_node* is not a key of *leaves*.
    """
    parent = expand_node
    leaves.pop(parent)

    for fanin in (parent.left, parent.right):
        assert fanin is not None
        first_visit = fanin not in visited_nodes
        if first_visit:
            visited_nodes[fanin] = []
            if fanin.is_pi():
                pis.append(fanin)
            else:
                leaves[fanin] = True
        visited_nodes[fanin].append(parent)

        zero_expansion(visited_nodes, leaves, pis)


def collect_cut_data(
    root_node: _Learned_Node, num_inputs: int
) -> Tuple[
    List[_Learned_Node] | None,
    Dict[_Learned_Node, bool] | None,
    Dict[_Learned_Node, List[_Learned_Node]] | None,
]:
    """Grow a random cut below *root_node* until it has *num_inputs* inputs.

    Starting from the frontier ``{root_node}``, a uniformly random frontier
    node is expanded with ``expand_node`` until the frontier leaves plus the
    reached PIs number at least *num_inputs*.

    Returns:
        ``(pis, leaves, visited_nodes)`` describing the cut. Returns
        ``(None, None, None)`` if the frontier runs dry before the target
        size is reached.
    """
    reached: Dict[_Learned_Node, List[_Learned_Node]] = {}
    frontier: Dict[_Learned_Node, bool] = {root_node: True}
    inputs: List[_Learned_Node] = []

    while len(frontier) + len(inputs) < num_inputs:
        if not frontier:
            return (None, None, None)
        candidates = list(frontier)
        expand_node(random.choice(candidates), reached, frontier, inputs)

    return inputs, frontier, reached


def zero_expansion(visited_nodes, leaves, pis):
    """Expand, in place, every frontier leaf that has a fan-in already visited.

    Expanding such a leaf adds at most one new node to the cut, so the cut
    never grows in size. The leaf is removed from *leaves*. Its other fan-in
    is registered in *visited_nodes* (and put into *pis* or *leaves*) if it
    is new. The leaf is appended to the parent lists of its right fan-in and
    then its left fan-in. After every expansion the scan restarts over the
    current *leaves*, so the process runs until no leaf qualifies.
    """
    pending = collections.deque(leaves)
    while pending:
        leaf = pending.popleft()
        lhs = leaf.left
        rhs = leaf.right
        assert lhs is not None
        assert rhs is not None

        if lhs in visited_nodes:
            other = rhs
        elif rhs in visited_nodes:
            other = lhs
        else:
            continue

        del leaves[leaf]
        if other not in visited_nodes:
            visited_nodes[other] = []
            if other.is_pi():
                pis.append(other)
            else:
                leaves[other] = True
        visited_nodes[rhs].append(leaf)
        visited_nodes[lhs].append(leaf)
        # The frontier changed; rescan it from the start.
        pending = collections.deque(leaves)


def create_cut(root_node: _Learned_Node, num_inputs: int) -> Learned_AIG | None:
    """Find a *num_inputs*-input cut below *root_node* by breadth-first search.

    The search starts at the root's two fan-ins and expands non-PI nodes in
    FIFO order. PIs that are reached are kept as leaves.

    Whenever the frontier plus the PI leaves number exactly *num_inputs*, the
    frontier is checked. If none of its nodes has a fan-in already inside the
    explored cone, it is accepted as the rest of the cut and the search stops.
    Otherwise the first offending node is swapped to the front so it is
    expanded next.

    Two behaviours were wrong before:
      * The search did not stop after accepting a frontier. It kept walking
        the whole fan-in cone down to the PIs and appended more leaves.
      * When no frontier matched and the cone had more than *num_inputs* PIs,
        all of them were passed on. ``constrct_cut`` maps only the first
        *num_inputs* leaves, so the root could not be rebuilt and the call
        raised KeyError. Such roots now return None.

    Args:
        root_node: Non-PI node to cut below.
        num_inputs: Required number of cut inputs.

    Returns:
        The cut built by ``constrct_cut``, or None if no cut with exactly
        *num_inputs* inputs was found.
    """
    # node -> nodes expanded into it (its parents within the explored cone)
    visited_nodes: Dict[_Learned_Node, List[_Learned_Node]] = {}
    frontier: collections.deque = collections.deque()
    leaves: List[_Learned_Node] = []

    left = root_node.left
    right = root_node.right
    visited_nodes[left] = [root_node]
    visited_nodes[right] = [root_node]
    for child in (left, right):
        if child.is_pi():
            leaves.append(child)
        else:
            frontier.append(child)

    while frontier:
        cur_node = frontier.popleft()
        for child in (cur_node.left, cur_node.right):
            if child not in visited_nodes:
                if child.is_pi():
                    leaves.append(child)
                else:
                    frontier.append(child)
                visited_nodes[child] = []
            visited_nodes[child].append(cur_node)

        if len(frontier) + len(leaves) == num_inputs:
            # The frontier is only a valid cut boundary if no frontier node
            # feeds from a node that is already inside the cone.
            blocking = next(
                (
                    i
                    for i, node in enumerate(frontier)
                    if node.left in visited_nodes or node.right in visited_nodes
                ),
                None,
            )
            if blocking is None:
                leaves.extend(frontier)
                break
            # Expand the offending node next.
            frontier[blocking], frontier[0] = frontier[0], frontier[blocking]

    if len(leaves) != num_inputs:
        return None

    return constrct_cut(root_node, leaves, visited_nodes, num_inputs)


def constrct_cut(
    root_node: _Learned_Node, leaves, visited_nodes, num_inputs: int
) -> Learned_AIG:
    """Copy the logic between *leaves* and *root_node* into a new AIG.

    The spelling of the name is kept because ``create_cut`` calls it.

    The walk starts from the leaves and moves upward through the parent lists
    stored in ``visited_nodes``. A node is copied into the new AIG once both
    of its fan-ins have been copied.

    Args:
        root_node: Node whose function the cut computes. It drives the single
            PO of the new AIG.
        leaves: Cut inputs. The first ``num_inputs`` are mapped, in order, to
            PI ids 1..num_inputs; ``zip`` silently ignores any extras.
        visited_nodes: Maps each explored node to the nodes expanded into it
            (its parents). Every leaf must be a key.
        num_inputs: Number of PIs of the new AIG.

    Returns:
        The new AIG.

    Raises:
        KeyError: If ``root_node`` was never copied, at
            ``name_map[root_node]``.
    """
    aig = Learned_AIG(
        n_pis=num_inputs, n_pos=1, truth_tables=None, skip_truth_tables=True
    )
    # original node -> node id in the new AIG; leaves become PIs 1..num_inputs
    name_map = dict(zip(leaves, range(1, num_inputs + 1)))
    # Min-heap keyed on level, so lower-level nodes are popped first.
    # NOTE(review): on equal levels the tuples compare the nodes themselves;
    # this assumes _Learned_Node defines an ordering.
    heap = []
    for node in leaves:
        # visited_nodes[...] holds parents, so "child" here is a parent node.
        heap.extend([(child.level, child) for child in visited_nodes[node]])
    heapq.heapify(heap)
    while len(heap) > 0:
        node = heapq.heappop(heap)[1]
        # A node can be pushed once per copied fan-in. Duplicates are skipped
        # here once it is copied; earlier pops with a missing fan-in do nothing.
        if node not in name_map:
            # if name_map[node] > num_inputs:
            if node.left in name_map and node.right in name_map:
                new_aig_node = aig.create_and(
                    name_map[node.left],
                    name_map[node.right],
                    node.left_edge_type,
                    node.right_edge_type,
                )
                name_map[node] = new_aig_node.node_id
                # Do not climb above the root of the cut.
                if node != root_node:
                    for child in visited_nodes[node]:
                        heapq.heappush(heap, (child.level, child))

    # NOTE(review): the meaning of the -1, 1 arguments is defined by
    # Learned_AIG.set_po_edge (not visible here).
    aig.set_po_edge(name_map[root_node], -1, 1)

    return aig


def construct_cut2(
    root_node: _Learned_Node, leaves, visited_nodes, num_inputs: int
) -> Learned_AIG:
    """Copy the logic between *leaves* and *root_node* into a new AIG.

    Every key of ``visited_nodes`` that is not a leaf is copied as an AND
    node, in ascending level order. ``root_node`` is copied last and drives
    the single PO.

    Args:
        root_node: Output node of the cut. Callers in this module do not put
            it into ``visited_nodes``; it is created separately at the end.
        leaves: Cut inputs, mapped in order to PI ids 1..num_inputs.
        visited_nodes: All nodes inside the cut, leaves included, as keys.
            The values are not used here.
        num_inputs: Number of PIs of the new AIG.

    Returns:
        The new AIG.

    Raises:
        KeyError: If a copied node has a fan-in that is neither a leaf nor
            copied before it.
    """
    aig = Learned_AIG(
        n_pis=num_inputs, n_pos=1, truth_tables=None, skip_truth_tables=True
    )
    name_map = dict(zip(leaves, range(1, num_inputs + 1)))
    # Ascending level, so fan-ins are copied before their users (assuming a
    # node's level exceeds its fan-ins' levels).
    # NOTE(review): equal levels compare the nodes; this assumes
    # _Learned_Node is orderable.
    heap = [(node.level, node) for node in visited_nodes.keys()]
    heapq.heapify(heap)
    while heap:
        _, node = heapq.heappop(heap)
        if node in name_map:
            continue
        new_aig_node = aig.create_and(
            name_map[node.left],
            name_map[node.right],
            node.left_edge_type,
            node.right_edge_type,
        )
        name_map[node] = new_aig_node.node_id
    root = aig.create_and(
        name_map[root_node.left],
        name_map[root_node.right],
        root_node.left_edge_type,
        root_node.right_edge_type,
    )
    # Pass the node id, as every other set_po_edge call in this module does.
    # The node object itself used to be passed here.
    aig.set_po_edge(root.node_id, -1, 1)

    return aig


def get_full_cut(root_node: _Learned_Node) -> Learned_AIG:
    """Copy the entire fan-in cone of *root_node* into a new AIG.

    Every PI reachable from ``root_node`` becomes a PI of the new AIG,
    numbered 1, 2, ... in breadth-first discovery order. The AND nodes of the
    cone are then copied in depth-first post-order, so each node's fan-ins
    exist before the node itself is created.

    The previous version copied nodes in reverse BFS-discovery order. That is
    not a topological order when the cone reconverges: a node found early via
    a short path can also be a fan-in of a node discovered later. In that
    case the later node was wired to the placeholder id 0 instead of the real
    fan-in.

    Args:
        root_node: A non-PI node. It drives the single PO of the result.

    Returns:
        A new Learned_AIG whose PO is driven by the copy of ``root_node``.
    """
    # Pass 1: breadth-first search that numbers the PIs of the cone.
    name_map: Dict[_Learned_Node, int] = {}
    seen = {root_node}
    frontier = collections.deque([root_node])
    while frontier:
        node = frontier.popleft()
        for child in (node.left, node.right):
            if child in seen:
                continue
            seen.add(child)
            if child.is_pi():
                name_map[child] = len(name_map) + 1
            else:
                frontier.append(child)

    cut = Learned_AIG(len(name_map), 1, None)

    # Pass 2: iterative post-order DFS, so fan-ins are copied before users.
    stack = [(root_node, False)]
    while stack:
        node, fanins_ready = stack.pop()
        if node in name_map:
            continue
        if fanins_ready:
            new_node = cut.create_and(
                name_map[node.left],
                name_map[node.right],
                node.left_edge_type,
                node.right_edge_type,
            )
            name_map[node] = new_node.node_id
        else:
            stack.append((node, True))
            # Push the right fan-in first so the left sub-cone is copied first.
            for child in (node.right, node.left):
                if child not in name_map:
                    stack.append((child, False))

    cut.set_po_edge(name_map[root_node], -1, 1)
    return cut


def parallel_create_multiple_cut_data_helper(
    queue: Any, num_inputs: int, path_prefix: str = "."
) -> None:
    """Worker: run create_multiple_cut_data for every queued task.

    Tasks are ``(path, start, end)`` tuples. The worker returns when the queue
    is empty.
    """
    from queue import Empty

    while True:
        # Checking empty() and then calling get() races with the other
        # workers: one of them can take the last item in between, leaving
        # get() blocked forever. get_nowait() checks and takes in one step.
        try:
            f, start, end = queue.get_nowait()
        except Empty:
            return
        create_multiple_cut_data(f, num_inputs, path_prefix, start, end)


def parallel_create_cut_data_helper(
    queue: Any, num_inputs: int, path_prefix: str = "."
) -> None:
    """Worker: run create_cut_data for every queued task.

    Tasks are ``(path, start, end)`` tuples. The worker returns when the queue
    is empty.
    """
    from queue import Empty

    while True:
        # Checking empty() and then calling get() races with the other
        # workers: one of them can take the last item in between, leaving
        # get() blocked forever. get_nowait() checks and takes in one step.
        try:
            f, start, end = queue.get_nowait()
        except Empty:
            return
        create_cut_data(f, num_inputs, path_prefix, start, end)


def parallel_create_multiple_cut_data(
    file_paths: List[str],
    num_inputs: int,
    path_prefix: str = ".",
    num_workers: int = 2,
    chunk_size: int | None = None,
) -> None:
    """Run create_multiple_cut_data over *file_paths* with worker processes.

    Work items ``(path, start, end)`` are put on a managed queue, which
    *num_workers* processes drain via
    ``parallel_create_multiple_cut_data_helper``. If *chunk_size* is given,
    every file is split into node ranges of that size. Otherwise each file is
    one item covering all of its nodes.
    """
    with mp.Manager() as manager:
        queue = manager.Queue()
        for f in file_paths:
            if chunk_size is None:
                queue.put((f, 0, None))
                continue
            # The AIGER header line is "aig M I L O A"; field 1 is M, the
            # maximum variable index. The file is now closed after reading.
            with open(f, "rb") as header_file:
                num_nodes = int(header_file.readline().split()[1])
            for start in range(0, num_nodes, chunk_size):
                end = start + chunk_size
                # Leave the last chunk open-ended. AIGER variables run 1..M
                # with 0 the constant, so the node list presumably has at
                # least M + 1 entries. A fixed end of exactly M would drop the
                # last node whenever chunk_size divides M.
                queue.put((f, start, end if end < num_nodes else None))

        processes = []
        for _ in range(num_workers):
            p = mp.Process(
                target=parallel_create_multiple_cut_data_helper,
                args=[queue, num_inputs, path_prefix],
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()


def parallel_create_cut_data(
    file_paths: List[str],
    num_inputs: int,
    path_prefix: str = ".",
    num_workers: int = 2,
    chunk_size: int | None = None,
) -> None:
    """Run create_cut_data over *file_paths* with worker processes.

    Work items ``(path, start, end)`` are put on a managed queue, which
    *num_workers* processes drain via ``parallel_create_cut_data_helper``. If
    *chunk_size* is given, every file is split into node ranges of that size.
    Otherwise each file is one item covering all of its nodes.
    """
    with mp.Manager() as manager:
        queue = manager.Queue()
        for f in file_paths:
            if chunk_size is None:
                queue.put((f, 0, None))
                continue
            # The AIGER header line is "aig M I L O A"; field 1 is M, the
            # maximum variable index. The file is now closed after reading.
            with open(f, "rb") as header_file:
                num_nodes = int(header_file.readline().split()[1])
            for start in range(0, num_nodes, chunk_size):
                end = start + chunk_size
                # Leave the last chunk open-ended. AIGER variables run 1..M
                # with 0 the constant, so the node list presumably has at
                # least M + 1 entries. A fixed end of exactly M would drop the
                # last node whenever chunk_size divides M.
                queue.put((f, start, end if end < num_nodes else None))

        processes = []
        for _ in range(num_workers):
            p = mp.Process(
                target=parallel_create_cut_data_helper,
                args=[queue, num_inputs, path_prefix],
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()


def parallel_optimize_cut_data(data_path: str, abc_path: str, num_workers: int) -> None:
    """Optimise every AIG matching the glob *data_path* with ABC, in parallel.

    The matched files are split into consecutive slices of
    ``len // num_workers + 1`` files each. Each slice is handed to
    ``optimize_cut_data_helper`` in a process pool, together with its worker
    index (used as the tqdm bar position).
    """
    file_paths = glob(data_path, recursive=True)
    per_worker = len(file_paths) // num_workers + 1
    file_list = [
        file_paths[offset : offset + per_worker]
        for offset in range(0, len(file_paths), per_worker)
    ]
    jobs = zip(file_list, [abc_path] * num_workers, range(num_workers))

    with mp.Pool(num_workers) as pool:
        pool.starmap(optimize_cut_data_helper, jobs)


def optimize_cut_data_helper(file_paths: List[str], abc_path: str, pos: int) -> None:
    """Run ABC's resyn2 script on each file in *file_paths*.

    Each ``.../unoptimized/...`` input is written to the same path with
    "unoptimized" replaced by "optimized". Missing output directories are
    created. ABC's output is discarded and its exit status is ignored, as
    before.

    Args:
        file_paths: Input AIG files.
        abc_path: Path of the ABC executable. A relative path is resolved
            against the current directory (as with the former "./" prefix).
        pos: tqdm progress-bar position for this worker.

    Raises:
        FileNotFoundError: If the ABC executable does not exist. The old
            shell call silently did nothing in that case.
    """
    import subprocess

    resyn2 = "balance; rewrite; refactor; balance; rewrite; rewrite -z; balance; refactor -z; rewrite -z; balance"
    # The former f"./{abc_path}" turned an absolute path such as /opt/abc
    # into the relative .//opt/abc. Only prefix "./" to relative paths.
    abc_exe = abc_path if os.path.isabs(abc_path) else os.path.join(".", abc_path)
    for unoptimized_aig in tqdm(file_paths, position=pos, leave=True):
        # NOTE(review): if a path does not contain "unoptimized", the output
        # path equals the input path and the input file is overwritten.
        optimized_aig = unoptimized_aig.replace("unoptimized", "optimized")
        out_dir = os.path.dirname(optimized_aig)
        if out_dir:  # os.makedirs("") raises for bare file names
            os.makedirs(out_dir, exist_ok=True)
        # Pass an argv list with no shell, so file names cannot inject shell
        # syntax. Output goes to DEVNULL instead of ">/dev/null 2>&1".
        subprocess.run(
            [abc_exe, "-c", f"read {unoptimized_aig}; {resyn2}; write {optimized_aig}"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )


def identify_large_cuts(
    file_queue: mp.Queue, tasks_queue: mp.Queue, min_inputs: int, max_inputs: int
) -> None:
    """Worker: queue (file, root_id, leaf_level) cut tasks for each AIG file.

    For every file taken from *file_queue*, leaf levels 0, 1, 2, ... are
    tried in turn. For each node whose level-``leaf_level`` input set (from
    ``construct_node_set_level``) has between *min_inputs* and *max_inputs*
    members, ``(aig_file, node_index, leaf_level)`` is put on *tasks_queue*.
    Scanning a file stops at the first level that yields no such node. The
    worker returns once *file_queue* is drained.
    """
    from queue import Empty

    while True:
        # Checking empty() and then calling get() races with the other
        # workers: one of them can take the last file in between, leaving
        # get() blocked forever. get_nowait() checks and takes in one step.
        try:
            aig_file = file_queue.get_nowait()
        except Empty:
            return
        aig = Learned_AIG.read_aig(aig_file)
        cur_level = 0
        found_any = True
        while found_any:
            found_any = False
            node_sets = construct_node_set_level(cur_level, aig)
            for root_id, node_set in enumerate(node_sets):
                if min_inputs <= len(node_set) <= max_inputs:
                    found_any = True
                    tasks_queue.put((aig_file, root_id, cur_level))
            cur_level += 1


def construct_node_set_level(leaf_level: int, aig: Learned_AIG) -> List[set]:
    """For each node, collect the ids of level-*leaf_level* nodes its cone rests on.

    Entry i belongs to ``aig._nodes[i]``:
      * level below *leaf_level*: an empty set;
      * level equal to *leaf_level*: ``{node_id}``;
      * level above: the union of both fan-ins' sets, or an empty set if
        either fan-in's set is empty (i.e. some path drops below
        *leaf_level*).

    Empty entries used to be lists (``[]``) while the others were sets,
    contradicting the ``List[set]`` return type. Every entry is now a set.
    Fan-ins are looked up by ``node_id`` while entries are appended by
    position. NOTE(review): this assumes ``aig._nodes[i].node_id == i`` and
    that fan-ins precede their users in ``aig._nodes``.
    """
    node_sets: List[set] = []
    for n in tqdm(aig._nodes):
        if n.level < leaf_level:
            node_sets.append(set())
        elif n.level == leaf_level:
            node_sets.append({n.node_id})
        else:
            left_set = node_sets[n.left.node_id]
            right_set = node_sets[n.right.node_id]
            if left_set and right_set:
                node_sets.append(left_set | right_set)
            else:
                node_sets.append(set())
    return node_sets


def extract_large_cuts(
    tasks_queue: mp.Queue,
    path_prefix: str,
) -> None:
    """Worker: build and write the cut described by each queued task.

    Each task ``(aig_file, root_id, leaf_level)`` is processed as follows.
    The search runs breadth-first below ``aig._nodes[root_id]`` and stops at
    nodes whose level equals *leaf_level*; those become the cut inputs. The
    cut is built with ``construct_cut2`` and written to
    ``{path_prefix}/{n_pis}_inputs/{file stem}/{root_id}.aig``. The AIG file
    is re-read only when it differs from the previous task's file. The
    worker returns once the queue is drained.

    The BFS expects every path below the root to reach *leaf_level*, which
    ``construct_node_set_level`` enforces for the tasks
    ``identify_large_cuts`` emits. Otherwise the assert on a PI's missing
    fan-in would fire.
    """
    from queue import Empty

    cur_file = ""
    aig = None
    while True:
        # Checking empty() and then calling get() races with the other
        # workers: one of them can take the last task in between, leaving
        # get() blocked forever. get_nowait() checks and takes in one step.
        try:
            aig_file, root_id, leaf_level = tasks_queue.get_nowait()
        except Empty:
            return
        if cur_file != aig_file:
            aig = Learned_AIG.read_aig(aig_file)
            cur_file = aig_file

        root = aig._nodes[root_id]
        node_queue = collections.deque([root])
        # node -> parents that were expanded into it
        visited_nodes = {}
        leaves = []
        while node_queue:
            cur_node = node_queue.popleft()
            for node in (cur_node.left, cur_node.right):
                assert node is not None
                if node not in visited_nodes:
                    visited_nodes[node] = []
                    if node.level == leaf_level:
                        leaves.append(node)
                    else:
                        node_queue.append(node)
                visited_nodes[node].append(cur_node)
        cut = construct_cut2(root, leaves, visited_nodes, len(leaves))

        dir_name = os.path.splitext(os.path.basename(aig_file))[0]
        new_dir = f"{path_prefix}/{cut.n_pis()}_inputs/{dir_name}/"
        os.makedirs(new_dir, exist_ok=True)
        cut.write_aig(new_dir + str(root_id) + ".aig")
        del cut


def parallel_extract_large_cuts(
    aig_files: List[str],
    min_inputs: int,
    max_inputs: int,
    num_workers: int,
    path_prefix: str = ".",
) -> None:
    """Find and write all cuts with *min_inputs*..*max_inputs* inputs.

    The work runs in two pool phases that share managed queues. First,
    ``identify_large_cuts`` workers turn the queued files into cut tasks.
    Once they have all finished, ``extract_large_cuts`` workers build the
    tasks and write them under *path_prefix*.
    """
    with mp.Manager() as manager:
        file_queue = manager.Queue()
        tasks_queue = manager.Queue()
        for path in aig_files:
            file_queue.put(path)

        # Phase 1: drain file_queue, filling tasks_queue.
        identify_args = [(file_queue, tasks_queue, min_inputs, max_inputs)] * num_workers
        with mp.Pool(num_workers) as pool:
            pool.starmap(identify_large_cuts, identify_args)

        # Phase 2: drain tasks_queue, writing each cut to disk.
        extract_args = [(tasks_queue, path_prefix)] * num_workers
        with mp.Pool(num_workers) as pool:
            pool.starmap(extract_large_cuts, extract_args)


if __name__ == "__main__":
    # Manual entry point: uncomment one of the experiments below to run it.
    # The printed time covers whatever runs between the two timestamps.
    started_at = time.time()
    # parallel_optimal_aigs_BFS(4, 4)
    # create_cut_data("./data/ctrl.aig", 5, "./data/unoptimized")
    # for f in glob('data/EPFL/**/*.aig', recursive=True):
    #     create_cut_data(f, 6, "./data/unoptimized")
    # parallel_create_cut_data(glob('data/EPFL/**/*.aig', recursive=True), 7, "./data/unoptimized", 4)
    # create_multiple_cut_data("data/EPFL/arithmetic/adder.aig", 4, "./data/unoptimized")
    # parallel_optimize_cut_data('data/test/**/*.aig', 'src/data/abc/build/abc', 2)
    # print(glob('data/EPFL/**/*.aig', recursive=True))
    # parallel_extract_large_cuts(3, 3, 4)
    elapsed = time.time() - started_at
    # new_parallel_optimal_aigs_BFS(3, 3)
    # generate_optimal_aigs(3, 3)
    # aig = Learned_AIG.read_aig("./data/ctrl.aig")
    # c = get_full_cut(aig[15])
    # c.write_aig("./data/test.aig")
    print("Parallel time:", elapsed)
    # print(timeit.timeit(lambda: parallel_optimal_aigs_BFS(3, 3), number=10))

    # cProfile.run("parallel_optimal_aigs_BFS(3, 3)", "my_func_stats")
    # p = pstats.Stats("my_func_stats")
    # p.sort_stats("cumulative").print_stats()
