from __future__ import absolute_import
from __future__ import print_function
import sys
import os, time, json, re, pickle
from optparse import OptionParser
import networkx as nx
from multiprocessing import Pool
# the next line can be removed after installation
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import pyverilog
from pyverilog.vparser.parser import parse
from AST_analyzer import *
from AST_analyzer import AST_analyzer
# from logicTree import *


# JSON file under the LS-benchmark tree, presumably describing the benchmark
# designs. NOTE(review): not referenced by any code visible in this file --
# confirm whether it is still needed.
design_json = "/home/coguest5/LS-benchmark/design.json"

def BFSBackTrace(G, node_dict, start_node):
    """Collect the cone of nodes reachable from ``start_node`` by BFS.

    Starting at ``start_node``, edges are followed via ``G.successors`` level
    by level. A successor is recorded as an endpoint but NOT expanded further
    when either:

    * its ``type`` is ``'Reg'`` or ``'Input'``, or
    * its ``type`` is ``'Pointer'`` or ``'Partselect'`` and the node named by
      its ``father`` attribute has type ``'Reg'`` or ``'Input'``.

    Successors whose name equals ``start_node`` after stripping a trailing
    ``.PTR<digits>`` suffix are never treated as endpoints, so tracing can
    continue through the start signal's own pointer nodes.

    Args:
        G: directed graph exposing ``successors(node)`` (e.g. a
            ``networkx.DiGraph``).
        node_dict: mapping from node name to an object with a ``type``
            attribute (and ``father`` for pointer / part-select nodes).
        start_node: name of the node to trace from.

    Returns:
        set: all visited nodes, including ``start_node`` and the endpoints.
    """
    from collections import deque  # O(1) popleft; list.pop(0) is O(n)

    ptr_suffix = re.compile(r'\.PTR(\d+)$')  # compiled once, used per edge
    boundary_types = ('Reg', 'Input')

    print(start_node)
    visited = {start_node}
    queue = deque([start_node])

    while queue:
        u = queue.popleft()
        for v in G.successors(u):
            # The start signal's own pointer nodes ("<start>.PTR<n>") must not
            # stop the search, so only other nodes are checked as boundaries.
            if ptr_suffix.sub('', v) != start_node:
                node = node_dict[v]
                if node.type in boundary_types:
                    # Register / primary input: record it, do not expand.
                    visited.add(v)
                    print('end: ', v)
                    continue
                if (node.type in ('Pointer', 'Partselect')
                        and node_dict[node.father].type in boundary_types):
                    # Bit/part select of a register or input: also a boundary.
                    visited.add(v)
                    print('end: ', v)
                    continue

            if v not in visited:
                visited.add(v)
                queue.append(v)
    return visited

def run_all_bit_one_ep(para_lst):
    """Extract, print and save the cone sub-graph for one endpoint.

    Args:
        para_lst: sequence ``[ep_lst, g_nx, node_dict, ep, subgraph_save_dir]``
            packed into one argument so the function can be used with
            ``Pool.map``. ``ep_lst`` is accepted for compatibility but not
            used. An optional sixth element gives the design name; when it
            is absent the name is taken from the last path component of
            ``subgraph_save_dir`` (``run_one_design`` builds that directory
            as ``../cone_graph/<cmd>/<design_name>``).

    Side effects:
        Prints a reference graph (if its pickle exists) and the extracted
        cone, then pickles the extracted ``networkx.DiGraph`` to
        ``<subgraph_save_dir>/<ep>.pkl``.
    """
    _ep_lst, g_nx, node_dict, ep, subgraph_save_dir = para_lst[:5]

    # Resolve the design name locally instead of reading the ``design``
    # global, which only exists when this file runs as ``__main__`` (not in
    # spawned worker processes or when the module is imported).
    if len(para_lst) > 5:
        design_name = para_lst[5]
    else:
        design_name = os.path.basename(os.path.normpath(subgraph_save_dir))

    trace_set = BFSBackTrace(g_nx, node_dict, ep)

    # g_nx.subgraph() returns a view onto g_nx; copy it into a standalone
    # DiGraph before printing / pickling.
    sub_graph = nx.DiGraph(g_nx.subgraph(trace_set))

    # Diagnostic only: print the reference graph saved as "<ep>_ast.pkl" for
    # comparison. A missing reference file must not prevent saving the cone.
    ref_path = (f"/home/coguest5/hdl_fusion/data_collect/dataset/ori/graph/"
                f"{design_name}/{ep}_ast.pkl")
    if os.path.exists(ref_path):
        with open(ref_path, 'rb') as f:
            sub_graph0 = nx.DiGraph(pickle.load(f))
        print(sub_graph0)
    else:
        print('reference graph not found:', ref_path)
    print(sub_graph)
    print(sub_graph.nodes())

    with open(f"{subgraph_save_dir}/{ep}.pkl", 'wb') as f:
        pickle.dump(sub_graph, f, pickle.HIGHEST_PROTOCOL)


    

def run_one_design(design_name, mode=None, ep_filter=('addr',)):
    """Extract per-endpoint cone sub-graphs for one design.

    Loads ``<design_name>_<mode>.pkl`` and ``<design_name>_<mode>_node_dict.pkl``
    from the ``rtl_graph/<mode>`` dataset directory, and the endpoint list from
    ``label/ep_lst/<design_name>.json``. It then writes the node dictionary
    into ``../cone_graph/<mode>/<design_name>/`` and calls
    ``run_all_bit_one_ep`` for every selected endpoint.

    Args:
        design_name: design identifier used in all input / output file names.
        mode: graph-variant sub-directory (e.g. ``"ori"``). Defaults to the
            module-level ``cmd`` set in the ``__main__`` block.
        ep_filter: collection of endpoint names to process; ``None`` processes
            every endpoint. A single string is treated as a one-item
            collection. The default keeps the existing debugging behaviour of
            processing only ``'addr'``.
    """
    if mode is None:
        mode = cmd
    if isinstance(ep_filter, str):
        # Avoid substring matching ("a" in "addr") for a bare string.
        ep_filter = (ep_filter,)

    print('Current Design:', design_name)

    save_dir = f"/home/coguest5/hdl_fusion/data_collect/dataset/rtl_graph/{mode}"
    with open(os.path.join(save_dir, f"{design_name}_{mode}.pkl"), 'rb') as f:
        g = pickle.load(f)
    with open(os.path.join(save_dir, f"{design_name}_{mode}_node_dict.pkl"), 'rb') as f:
        node_dict = pickle.load(f)

    g_nx = nx.DiGraph(g)
    print(g_nx)

    with open(f"/home/coguest5/hdl_fusion/data_collect/label/ep_lst/{design_name}.json", 'r') as f:
        ep_lst = json.load(f)

    subgraph_save_dir = f"../cone_graph/{mode}/{design_name}"
    # makedirs also creates ../cone_graph/<mode> when it is missing (os.mkdir
    # would fail) and tolerates an already-existing directory.
    os.makedirs(subgraph_save_dir, exist_ok=True)

    with open(f"{subgraph_save_dir}/{design_name}_node_dict.pkl", 'wb') as f:
        pickle.dump(node_dict, f)

    para_lst = []
    print('# ep: ', len(ep_lst))

    for ep in ep_lst:
        if ep_filter is not None and ep not in ep_filter:
            continue
        para = [ep_lst, g_nx, node_dict, ep, subgraph_save_dir]
        para_lst.append(para)
        run_all_bit_one_ep(para)

    # para_lst is kept so the endpoints can be processed in parallel instead
    # of the sequential calls above, e.g.:
    #     with Pool(50) as p:
    #         p.map(run_all_bit_one_ep, para_lst)

    





if __name__ == '__main__':
    # Graph variant to process (alternative: "pos1"). This is already a
    # module-level name here, so no ``global`` statement is needed;
    # run_one_design uses it.
    cmd = "ori"
    design_js_path = "/home/coguest5/hdl_fusion/data_collect/dataset/json/design_lst/design_lst.json"
    with open(design_js_path, 'r') as f:
        design_lst = json.load(f)

    # NOTE: debugging override -- the list loaded above is discarded and only
    # this design is processed. Remove this line to run the full list.
    design_lst = ['b14']
    # ``design`` stays a module-level loop variable because
    # run_all_bit_one_ep may read it as a global.
    for design in design_lst:
        run_one_design(design)