import time, glob, pickle, argparse, os
import numpy as np
from .pyspice_utils import simulation
from .dataset_utils import reverse_feedback_loops, split_collection, split_graph, ocb_graph_with_split_gm

# Integer codes for circuit-graph node types, compared against each vertex's 'type'
# attribute (e.g. `v['type']`, `graph.vs['type']` in the functions below).
# 'In' and 'Out' identify the vertices `to_dag` uses as the start and end of its walk,
# and the vertices the single-connection helpers never remove.
# NOTE(review): the '+gm+'/'-gm-' style keys presumably encode transconductance
# polarities and 'sudo_in'/'sudo_out' presumably mean "pseudo" nodes — TODO confirm.
NODE_TYPE_OCB = {'R': 0,'C': 1,'+gm+':2,'-gm+':3,'+gm-':4,'-gm-':5, 'sudo_in':6, 'sudo_out':7, 'In': 8, 'Out':9}

def bfs(in_node_idx, out_node_idx, edge_index):
    '''
    Return a node ordering produced by a breadth-first walk from the input node.

    The walk only follows edges in their listed ``(src, dst)`` direction. The direct
    successors of ``in_node_idx`` are enqueued in edge-list order. Every later
    expansion is delegated to ``bfs_step``, which shuffles newly discovered children.
    ``out_node_idx`` is never expanded and is always appended as the last element.
    Nodes that the walk cannot reach are absent from the result.

    Args:
        in_node_idx: index of the node the walk starts from (first in the result).
        out_node_idx: index of the node forced to the end of the result.
        edge_index: sequence of ``(src, dst)`` index pairs.

    Returns:
        List of node indices: the visit order followed by ``out_node_idx``.
    '''
    visited = [in_node_idx]
    queue = [dst for (src, dst) in edge_index if src == in_node_idx]

    while queue:
        candidate = queue.pop(0)
        # The output node is reserved for the final slot, and nodes already placed
        # are skipped (the queue may hold duplicates).
        if candidate == out_node_idx or candidate in visited:
            continue
        visited, queue = bfs_step(candidate, visited, queue, edge_index)

    return visited + [out_node_idx]


def bfs_step(node_idx, visited, queue, edge_index):
    '''
    Expand one node of the ``bfs`` walk.

    Collects the successors of ``node_idx`` (targets of edges whose source is
    ``node_idx``, in edge-list order) that are not yet in ``visited``. It then
    shuffles them in place with the global NumPy RNG and appends them to the queue.
    The input lists are not mutated; new lists are returned.

    Returns:
        ``(visited + [node_idx], queue + shuffled_new_children)``
    '''
    fresh_children = [dst for (src, dst) in edge_index
                      if src == node_idx and dst not in visited]
    shuffled = np.array(fresh_children)
    # Random tie-breaking between siblings; consumes the global np.random state.
    np.random.shuffle(shuffled)

    new_visited = visited + [node_idx]
    new_queue = queue + shuffled.tolist()
    return new_visited, new_queue

def remove_single_connection_nodes(graph):
    '''
    Delete, in place, every non-'In'/'Out' vertex whose degree is exactly 2.

    Degree 2 is the author's criterion for a "single connection" node in these
    undirected graphs (one node connected to two edges). All degrees are read from the
    graph before anything is deleted, and the matching vertices are removed in a single
    ``delete_vertices`` call. No replacement edge is added between their former
    neighbours.

    Returns:
        The same ``graph`` object.
    '''
    terminal_types = (NODE_TYPE_OCB['In'], NODE_TYPE_OCB['Out'])
    doomed = [
        vertex.index
        for vertex in graph.vs
        if vertex['type'] not in terminal_types and graph.degree(vertex.index) == 2
    ]
    graph.delete_vertices(doomed)
    return graph

def are_any_single_connection_nodes(graph):
    '''
    Return True if some vertex other than 'In'/'Out' has degree exactly 2, else False.

    Uses the same "single connection" criterion as ``remove_single_connection_nodes``
    (degree 2 in these undirected graphs) and stops at the first match.
    '''
    terminal_types = (NODE_TYPE_OCB['In'], NODE_TYPE_OCB['Out'])
    return any(
        graph.degree(vertex.index) == 2
        for vertex in graph.vs
        if vertex['type'] not in terminal_types
    )


def to_dag(graph, undirected=True):
    '''
    Orient a circuit graph's edges so it becomes a DAG rooted at its 'In' vertex.

    The node ordering comes from ``bfs``. That is a randomised breadth-first walk that
    starts at the first 'In' vertex and always places the first 'Out' vertex last.
    Every edge ``(i, o)`` whose source precedes its target in that ordering is kept,
    and every other edge is dropped. The graph's edge set is replaced in place and the
    same object is returned.

    Args:
        graph: igraph-style graph exposing ``vs['type']``, ``get_edgelist()``,
            ``delete_edges()`` and ``add_edges()``. Vertex types are ``NODE_TYPE_OCB``
            codes.
        undirected: if True, the edge list is used exactly as ``get_edgelist()``
            returns it. NOTE(review): presumably it already holds both directions of
            each connection — confirm against how graphs are built. If False, the
            reverse of every edge is appended first, so the walk and the final
            orientation can use either direction.

    Returns:
        ``graph`` itself, with its edges replaced by the oriented subset.

    Raises:
        AssertionError: if the walk does not reach every vertex (e.g. a disconnected
            graph). It is raised explicitly rather than via ``assert`` so the check
            still runs under ``python -O``; the exception type is unchanged for callers.
        IndexError: if the graph has no 'In' or no 'Out' vertex.
    '''
    node_types = np.array(graph.vs['type'])
    n_nodes = len(node_types)

    # Endpoints of the walk: first vertex of type 'In' and first of type 'Out'.
    in_node_idx = int(np.where(node_types == NODE_TYPE_OCB['In'])[0][0])
    out_node_idx = int(np.where(node_types == NODE_TYPE_OCB['Out'])[0][0])

    edge_index = graph.get_edgelist()
    if not undirected:
        edge_index = edge_index + [(o, i) for (i, o) in edge_index]

    node_ordering = bfs(in_node_idx, out_node_idx, edge_index)

    # A walk that misses vertices yields a partial ordering. The position table below
    # would then be too short, making the orientation wrong or raising an index error,
    # so reject the graph explicitly. An `assert` would vanish under `python -O`.
    if len(node_ordering) != n_nodes:
        raise AssertionError(
            f'to_dag: ordering covers {len(node_ordering)} of {n_nodes} nodes '
            '(graph is not fully reachable from its In node)'
        )

    # position[v] = rank of vertex v in the ordering (inverse permutation).
    position = np.zeros(n_nodes, dtype=int)
    position[node_ordering] = np.arange(n_nodes)
    new_edges = [(i, o) for (i, o) in edge_index if position[i] < position[o]]

    graph.delete_edges()
    graph.add_edges(new_edges)
    return graph


if __name__ == "__main__":
    # Evaluate generated circuits. Every regular file in --out_path is loaded as a
    # pickled graph, optionally preprocessed and/or oriented with `to_dag`, then passed
    # to `simulation`. A circuit counts as valid when none of these steps raises.
    # The valid fraction is printed and written to <out_path>/../results.txt.

    ## Load cmd line args
    parser = argparse.ArgumentParser(description='Model Evaluation')

    parser.add_argument('--out_path', dest='output_path', type=str, required=True,
                        help='The path to the model output directory.')
    parser.add_argument('--undir', dest='undirected', action='store_true', default=False,
                        help='Whether the graph needs to be directed before simulation')
    parser.add_argument('--pins', dest='pins', action='store_true', default=False,
                        help='Whether graph are built considering input and output vccs pins')
    parser.add_argument('--prepro', dest='preprocess', action='store_true', default=False,
                        help='Whether to apply the following preprocessing steps: reverse feeback loop edges,' \
                        'split subcircuits into individual components, add electric nodes and gm pins')

    args = parser.parse_args()

    # Only regular files: a sub-directory would otherwise crash `open` below.
    graph_paths = [p for p in glob.glob(os.path.join(args.output_path, '*'))
                   if os.path.isfile(p)]
    if not graph_paths:
        # Without this guard the validity ratio below divides by zero.
        parser.error(f'no graph files found in {args.output_path}')

    success, fail = 0, 0
    t = time.time()

    for i, out_graph_p in enumerate(graph_paths):
        # NOTE(review): pickle.load can execute arbitrary code; only point --out_path
        # at files produced by a trusted model run. A corrupt file deliberately aborts
        # the run instead of being counted as an invalid circuit.
        with open(out_graph_p, 'rb') as f:
            out_graph = pickle.load(f)

        # `except Exception` (not a bare `except:`) so Ctrl+C / SystemExit still stop
        # the loop instead of being counted as failed circuits.
        try:
            if args.preprocess:
                # Search feedback loops & break subcircuits into elements
                directed_feedback_graph = reverse_feedback_loops([out_graph])
                # Add intermediate nodes between components. Named `split_graphs` so the
                # imported `split_graph` function is not rebound at module scope.
                split_graphs = split_collection(directed_feedback_graph, False)
                # Further split gms into three nodes with explicit i/o pins
                out_graph = ocb_graph_with_split_gm(split_graphs[0])
            if args.undirected:
                # Orient edges from In to Out; raises if the graph is not fully reachable.
                out_graph = to_dag(out_graph.copy(), args.undirected)
        except Exception:
            fail += 1
        else:
            try:
                # The result is discarded immediately; only success/failure matters.
                simulation(out_graph, 'default', args.pins, compute_features=False)
            except Exception:
                fail += 1
            else:
                success += 1

        # Reported for every index, including circuits that failed preprocessing.
        if i % 500 == 0:
            print(f'Iter {i}, time {time.time() - t}')

    valid = np.round(success / (success + fail), 3)
    # `valid` has 3 decimals, so one decimal place shows the percentage exactly
    # (avoids output such as 57.099999999999994).
    print(f'~~ Simulation done, {100 * valid:.1f} % of valid circuits. ~~~')
    with open(os.path.join(args.output_path, '..', 'results.txt'), 'w') as f:
        # NOTE(review): despite the "%" label this writes the fraction in [0, 1];
        # left unchanged for any downstream reader of results.txt.
        f.write(f"Validity %: {valid}")