from __future__ import absolute_import
from __future__ import print_function
import sys
import os, time, json, re, pickle
from optparse import OptionParser
import networkx as nx
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import pyverilog
from pyverilog.vparser.parser import parse
from AST_analyzer import *
from AST_analyzer import AST_analyzer
from multiprocessing import Pool

# Absolute, machine-specific path to a design JSON file.
# NOTE(review): this name is not referenced anywhere in the visible code —
# confirm whether it is still needed before relying on it.
design_json = "/home/coguest5/LS-benchmark/design.json"

def trace_graph(g, node_dict, ep, trace_set):
    """Collect into ``trace_set`` every node reachable from ``ep`` through logic.

    Starting at ``ep``, successors in ``g`` are followed transitively.  A node
    whose ``node_dict[node].tpe`` is ``'DFF'``, ``'Input'``, ``'Output'``,
    ``'Pointer'`` or ``None`` is a boundary node: it is added to ``trace_set``
    but its own successors are not followed.  Any other node is added and
    expanded further.

    Every ``'DFF'`` node met as a successor of an expanded node (including one
    that is already in ``trace_set``) is also added to the module-level global
    ``in_set``.  That global must be bound before calling this function;
    run_all_bit_one_ep() does so via ``global in_set``.

    Args:
        g: a networkx graph, or any data ``networkx.DiGraph`` accepts.
        node_dict: mapping node -> object exposing a ``tpe`` attribute.
        ep: start node; it is always added to ``trace_set`` and expanded.
        trace_set: set updated in place.

    Returns:
        None; results are delivered through ``trace_set`` and ``in_set``.
    """
    # Convert (at most) once.  Rebuilding a DiGraph copy at every step is
    # quadratic in the graph size, and recursing once per combinational level
    # can exceed Python's recursion limit on deep cones, so traverse with an
    # explicit stack instead.
    g_nx = g if isinstance(g, nx.DiGraph) else nx.DiGraph(g)
    boundary_types = ('DFF', 'Input', 'Output', 'Pointer', None)

    trace_set.add(ep)
    to_expand = [ep]
    while to_expand:
        current = to_expand.pop()
        for node in g_nx.successors(current):
            tpe = node_dict[node].tpe
            if tpe == 'DFF':
                in_set.add(node)  # module-level global, see docstring
            if node in trace_set:
                continue
            trace_set.add(node)
            if tpe not in boundary_types:
                to_expand.append(node)

def extract_FF_IO(ep, lines, in_=True):
    """Return the net connected to a flip-flop instance's D or Q pin.

    Scans ``lines`` for a single-line instantiation shaped like
    ``DFF<suffix>_X<n> <ep> ( ... .D(<net>) ... )``, or ``.Q(<net>)`` when
    ``in_`` is false, and returns ``<net>``.

    Args:
        ep: instance name.  It is matched literally, so names containing regex
            metacharacters (e.g. escaped Verilog identifiers with brackets)
            are found correctly.
        lines: netlist lines to search.
        in_: True selects the ``D`` pin, False selects the ``Q`` pin.

    Returns:
        The net name from the last matching line, or ``""`` if none matches.
    """
    pin = 'D' if in_ else 'Q'
    name = re.escape(str(ep))
    # Built once per call rather than once per scanned line.
    pattern = re.compile(rf"DFF(\S*)_X(\d+) {name} \((.*)\.{pin}\((\S+)\)(.*)\)")
    wire = ""
    for line in lines:
        found = pattern.findall(line)
        if found:
            wire = found[0][3]
    return wire

def get_cone_out(g, node_dict, trace_set):
    """Return the traced nodes that have at least one predecessor outside the cone.

    Only nodes whose ``node_dict[node].tpe`` is not ``'DFF'``, ``'Input'``,
    ``'Output'``, ``'Pointer'`` or ``None`` are reported.

    Args:
        g: a networkx graph, or any data ``networkx.DiGraph`` accepts.
        node_dict: mapping node -> object exposing a ``tpe`` attribute.
        trace_set: the nodes of the traced cone.

    Returns:
        A new set of nodes drawn from ``trace_set``.
    """
    excluded_types = ['DFF', 'Input', 'Output', 'Pointer', None]
    digraph = nx.DiGraph(g)
    boundary = set()
    for member in trace_set:
        # Structural test first: the type lookup only happens for nodes that
        # really have a predecessor outside the cone.
        has_outside_pred = any(
            pred not in trace_set for pred in digraph.predecessors(member)
        )
        if has_outside_pred and node_dict[member].tpe not in excluded_types:
            boundary.add(member)
    return boundary

def extract_gate_IO(gate, lines, in_=True):
    """Return the net driven by a combinational gate instance's output pin.

    Scans ``lines`` for a single-line instantiation shaped like
    ``<CELL>_X<n> <gate> ( ... .Z(<net>) ... )`` (pin ``Z`` or ``ZN``).

    Args:
        gate: instance name, matched literally (regex metacharacters escaped).
        lines: netlist lines to search.
        in_: must be False; input-pin lookup is not implemented.

    Returns:
        The output net from the last matching line, or ``""`` if none matches.

    Raises:
        NotImplementedError: if ``in_`` is true.
    """
    if in_:
        # An explicit raise, unlike ``assert False``, is not stripped by
        # ``python -O`` (which made this silently return None).
        raise NotImplementedError("extract_gate_IO only supports in_=False")
    name = re.escape(str(gate))
    pattern = re.compile(rf"(\S*)_X(\d+) {name} \((.*)\.(Z|ZN)\((\S+)\)(.*)\)")
    out_wire = ""
    for line in lines:
        found = pattern.findall(line)
        if found:
            out_wire = found[0][-2]  # the (\S+) group inside .Z(...) / .ZN(...)
    return out_wire


def _unique_nonempty(wires, exclude=()):
    """Return the non-empty names in ``wires``, first-seen order, no duplicates,
    and none of the names contained in ``exclude``."""
    seen = set(exclude)
    unique = []
    for wire in wires:
        if wire and wire not in seen:
            seen.add(wire)
            unique.append(wire)
    return unique


def _coi_module_header(module_line, extra_ports):
    """Rename a one-line ``module <name> (<ports>)<tail>`` header to ``module coi``
    and append ``extra_ports`` to its port list.

    Names already present in the original (comma-separated) port list, and
    repeated names in ``extra_ports``, are appended only once.

    Raises:
        ValueError: if ``module_line`` is not a one-line module header.
    """
    found = re.findall(r"(^module)(\s+)(\S+)(\s+)\((.*)\)(.*)", module_line)
    if not found:
        raise ValueError(f"cannot parse module header line: {module_line!r}")
    signal_line, final_line = found[0][4], found[0][-1]
    listed = {port.strip() for port in signal_line.split(',')}
    ports = [signal_line] if signal_line.strip() else []
    ports.extend(_unique_nonempty(extra_ports, exclude=listed))
    return f"module coi ({', '.join(ports)}) {final_line}\n"


def run_all_bit_one_ep(para_lst):
    """Write the combined traced cone of one endpoint group as a ``coi`` module.

    Args:
        para_lst: ``[ep_lst, design_name, g, node_dict, netlist_data, ep_name,
            save_dir]`` as built by run_one_design(), where ``netlist_data`` is
            the ``[def_line_lst, module_line_lst, inst_dict, lines]`` list
            returned by parse_netlist().

    Each node of ``ep_lst`` is traced with trace_graph().  The output file
    ``{save_dir}/{ep_name}.syn.v`` contains the original module/``input``/
    ``wire`` lines plus every traced instance line (in file order), with:

    * the header renamed to ``module coi`` and the new ports appended;
    * ``output`` declarations for the endpoints' Q nets and for the outputs of
      the gates reported by get_cone_out();
    * ``input`` declarations for the D nets of other DFFs reached by the trace.

    NOTE(review): parse_netlist() keeps only ``input``/``wire`` declaration
    lines, so the original module's ``output`` declarations are not copied
    while its port list is — confirm downstream tools accept this.

    Raises:
        ValueError: if no header/declaration lines exist or the header can't be parsed.
    """
    ep_lst, design_name, g, node_dict, netlist_data, ep_name, save_dir = (
        para_lst[0], para_lst[1], para_lst[2], para_lst[3],
        para_lst[4], para_lst[5], para_lst[6])
    def_line_lst, module_line_lst, inst_dict, lines = (
        netlist_data[0], netlist_data[1], netlist_data[2], netlist_data[3])

    # trace_graph() records reached DFFs in the module-level ``in_set``, so the
    # traced sets are (re)bound as globals before tracing.
    global trace_set, in_set, trace_set_all
    trace_set = set()
    in_set = set()
    trace_set_all = set()
    cone_out_set_all = set()
    ep_out_wires = []

    for ep in ep_lst:
        trace_set = set()
        trace_graph(g, node_dict, ep, trace_set)
        cone_out_set_all |= get_cone_out(g, node_dict, trace_set)
        ep_out_wires.append(extract_FF_IO(ep, lines, in_=False))
        trace_set_all |= trace_set

    # Endpoint bits reached from another bit's cone are not treated as inputs.
    in_set = in_set - set(ep_lst)

    # Extract each boundary net once, from the original netlist, so the port
    # list and the direction declarations are built from the same data; drop
    # duplicates, which would otherwise repeat ports/declarations.
    ep_out_wires = _unique_nonempty(ep_out_wires)
    in_wires = _unique_nonempty(
        extract_FF_IO(node, lines, in_=True) for node in in_set)
    cone_out_wires = _unique_nonempty(
        (extract_gate_IO(node, lines, in_=False) for node in cone_out_set_all),
        exclude=ep_out_wires)

    io_lines = ([f"  output {wire};\n" for wire in ep_out_wires]
                + [f"  input {wire};\n" for wire in in_wires]
                + [f"  output {wire};\n" for wire in cone_out_wires])

    # Header/declaration lines plus every traced instance, in original order.
    idx_lst = sorted(
        list(def_line_lst) + list(module_line_lst)
        + [inst_dict[n] for n in trace_set_all if n in inst_dict])
    out_lines = [lines[idx] for idx in idx_lst]
    if not out_lines:
        raise ValueError(f"no module/declaration lines to write for {ep_name!r}")

    out_lines[0] = _coi_module_header(
        out_lines[0], in_wires + ep_out_wires + cone_out_wires)
    out_lines.insert(1, ''.join(io_lines))

    with open(f'{save_dir}/{ep_name}.syn.v', 'w') as f:
        f.writelines(out_lines)

def run_one_design(design_name, processes=10):
    """Extract one cone-of-influence netlist per endpoint group of a design.

    Loads ``{design}_graph.pkl`` and ``{design}_node_dict.pkl`` (plus
    ``{design}_dff_set.pkl``) from ``../netlist_data/``, parses
    ``../data/netlist_clean/{design}.syn.v`` and reads the endpoint groups
    from ``../json/{design}.json``.  Then it runs run_all_bit_one_ep() for
    every group in a process pool, writing into ``../ori/vlg/{design}/``.
    All paths are relative to the current working directory.

    Args:
        design_name: design identifier used to build the paths above.
        processes: number of worker processes (default 10).
    """
    print('Current Design:', design_name)
    netlist_path = f'../data/netlist_clean/{design_name}.syn.v'

    save_dir = "../netlist_data/"
    # NOTE: pickle.load can execute arbitrary code; these files must come from
    # a trusted preprocessing run.
    with open(f"{save_dir}/{design_name}_graph.pkl", 'rb') as f:
        g = pickle.load(f)
    with open(f"{save_dir}/{design_name}_node_dict.pkl", 'rb') as f:
        node_dict = pickle.load(f)
    with open(f"{save_dir}/{design_name}_dff_set.pkl", 'rb') as f:
        dff_set = pickle.load(f)  # NOTE(review): loaded but not used below

    netlist_data = parse_netlist(netlist_path)

    with open(f"../json/{design_name}.json", 'r') as f:
        ep_lst_dict = json.load(f)

    out_dir = f"../ori/vlg/{design_name}"
    # makedirs(exist_ok=True) also creates missing parents and does not race
    # with a concurrent run creating the same directory.
    os.makedirs(out_dir, exist_ok=True)

    print('# ep: ', len(ep_lst_dict))
    para_lst = [
        [ep_lst, design_name, g, node_dict, netlist_data, ep, out_dir]
        for ep, ep_lst in ep_lst_dict.items()
    ]

    with Pool(processes) as p:
        p.map(run_all_bit_one_ep, para_lst)
        p.close()
        p.join()

        

def parse_netlist(file_path):
    """Read a netlist file and index the lines later copied into sub-netlists.

    Returns a 4-item list ``[def_line_lst, module_line_lst, inst_dict, lines]``:

    * ``def_line_lst``: indices of lines starting with ``"  input"`` or
      ``"  wire"``;
    * ``module_line_lst``: indices of lines starting with ``module`` or
      ``endmodule``;
    * ``inst_dict``: instance name -> line index, for lines that start like
      ``"  <CELL>_X<n> <name> ("``;
    * ``lines``: the raw file lines as returned by ``readlines()``.
    """
    with open(file_path, 'r') as fh:
        lines = fh.readlines()

    module_re = re.compile(r"(^module|^endmodule)(.*)")
    decl_re = re.compile(r"^  (input|wire)(.*)")
    inst_re = re.compile(r"^  (\S+)_X(\d+)(\s+)(\S+)(\s+)\(")

    module_line_lst = []
    def_line_lst = []
    inst_dict = {}
    for line_no, text in enumerate(lines):
        # Every pattern is anchored with '^', so match() is equivalent to
        # taking the first findall() hit.
        if module_re.match(text):
            module_line_lst.append(line_no)
        if decl_re.match(text):
            def_line_lst.append(line_no)
        inst_hit = inst_re.match(text)
        if inst_hit:
            inst_dict[inst_hit.group(4)] = line_no

    return [def_line_lst, module_line_lst, inst_dict, lines]




if __name__ == '__main__':
    # Usage: python <script> [design_list.json]
    # The optional argument overrides the default, machine-specific design
    # list; it must be a JSON list of design names for run_one_design().
    default_design_js_path = "/home/coguest5/rtl_repr/data_collect/dataset/json/design_lst/design_lst_p6.json"
    design_js_path = sys.argv[1] if len(sys.argv) > 1 else default_design_js_path
    with open(design_js_path, 'r') as f:
        design_lst = json.load(f)

    for design in design_lst:
        run_one_design(design)