import json, os, re, copy
from multiprocessing import Pool
# Module-level path settings.
# NOTE(review): design_hier_json is not referenced in the code visible here.
design_hier_json = "./design_hier.json"
# NOTE(review): the __main__ block rebinds design_json to an absolute path,
# so this relative default is not the manifest the script actually reads.
design_json = "../../design.json"
# Root directory of the LS-benchmark checkout (the flattened per-design
# netlists are read from <LS_dir>/<bench>/coi/).
LS_dir = "/home/coguest5/LS-benchmark/"


def get_coi_signal(line):
    """Return the set of signal names referenced by a Verilog expression.

    The expression is split on runs of spaces and each token is reduced to a
    bare name:

    * a single trailing comma is dropped;
    * tokens containing a sized literal (``4'b0000``) and bare bit/part
      selects (``[3]``, ``[7:0]``) are ignored;
    * ``name[hi:lo]`` and ``name[i]`` reduce to ``name``;
    * a nested index ``base[idx[...`` contributes both ``base`` and ``idx``.

    A token is kept only if it contains at least one of ``[A-Za-z0-9_.]``,
    which filters out pure-operator tokens such as ``?``, ``:`` or ``&``.
    """
    has_name_char = re.compile(r"[A-Za-z0-9_\.]").search
    found = set()

    for token in set(re.split(r"[ ]+", line)):
        token = re.sub(r",$", '', token)

        # Constants and detached selects carry no signal name.
        if (re.search(r"(\d+)'", token)
                or re.search(r"^\[(\d+)\]$", token)
                or re.search(r"^\[(\d+):(\d+)\]$", token)):
            continue

        nested = re.findall(r"(\S+)\[(\D+)\[", token)
        if nested:
            for part in nested[0]:
                if has_name_char(part):
                    found.add(part)
            continue

        # Part select and bit select are mutually exclusive (both anchored at $).
        select = (re.findall(r"(\S+)\[(\d+):(\d+)\]$", token)
                  or re.findall(r"(\S+)\[(\d+)\]$", token))
        if select:
            token = select[0][0]

        if has_name_char(token):
            found.add(token)

    return found


def extra_one_signal(signal_name, lines, not_ep=True):
    """Collect the source lines that declare or drive ``signal_name``.

    Results go into module-level globals that ``run_one_ep`` resets before
    each endpoint:

    * ``coi_dict`` -- ``{line index: verilog text}`` of lines to keep.
    * ``ep_dict`` -- ``{0: port declaration, -1: always block}`` for the
      endpoint register (only filled when ``not_ep`` is False).
    * ``in_set`` / ``out_set`` -- port names of the generated module.
    * ``un_add_set`` -- grows by every signal read on the right-hand side of a
      matched assignment (the next frontier of the cone).

    Args:
        signal_name: Signal to trace. All backslashes are stripped (presumably
            the Verilog escaped-identifier marker). The name is then matched
            literally, and only as a whole identifier, not as the tail of a
            longer one.
        lines: Source lines of the flattened Verilog file.
        not_ep: True for a signal inside the cone; False for the endpoint
            register itself.

    A register reached while ``not_ep`` is True is a cut point. It is
    re-declared as a module input with its bit range, and scanning stops for
    that signal (fan-in collected so far is discarded).
    """
    global un_add_set

    coi_signal = set()

    # Strip backslashes once, not on every line.
    signal_name = signal_name.replace("\\", "")
    # Match the name literally. Splicing it in raw made names with "(" raise
    # re.error, names with "$" never match, and "." match any character.
    name = re.escape(signal_name)
    # Reject hits where the name is only the tail of a longer identifier
    # ("q" inside "aq"); whitespace, "]", "(" or "\" may still precede it.
    not_ident_tail = r"(?<![A-Za-z0-9_$.])"

    def_re = re.compile(
        r"^  (reg|wire|input|output|inout)(.*)(\s*)" + not_ident_tail + name + r"(\s*);")
    def_reg_re = re.compile(r"^  reg(.*)(\s+)" + name + r"(\s*);")
    seq_plain_re = re.compile(
        not_ident_tail + name + r"(\s*)(\[(.*)\])*(\s+)<=(\s+)(.*);")
    seq_select_re = re.compile(
        not_ident_tail + name + r"(\s*)(\[(.*)\])*(\s*)<=(\s*)(.*)(\s*)(\[(.*)\]);")
    comb_re = re.compile(not_ident_tail + name + r"(\[(.*)\])*(\s+)=(\s+)(.*);")

    for idx, line in enumerate(lines):
        def_reg_line = def_reg_re.findall(line)

        if def_reg_line and not_ep:
            # Cut point: keep the register's bit range (group 1), make it an input.
            width = def_reg_line[0][0]
            coi_dict[idx] = f"  input {width} {signal_name};\n"
            in_set.add(signal_name)
            return
        if def_reg_line:
            # The endpoint register itself becomes the module's output.
            ep_dict[0] = f'  output reg {def_reg_line[0][0]} {signal_name};\n'
            out_set.add(signal_name)

        if not not_ep:
            seq_select = seq_select_re.findall(line)
            seq_plain = [] if seq_select else seq_plain_re.findall(line)
            if seq_select or seq_plain:
                # NOTE(review): the "select" form captures the RHS without its
                # trailing bit select, so the emitted always block drops it --
                # confirm that is intended.
                rhs = seq_select[0][-4] if seq_select else seq_plain[0][-1]
                ep_dict[-1] = f"  always @(posedge clk) begin\n    {signal_name} <= {rhs};\n  end"
                coi_signal |= get_coi_signal(rhs)

        def_line = def_re.findall(line) if not_ep else []
        if def_line:
            coi_dict[idx] = line
            # Classify by the declaration keyword, not by substring search on
            # the whole line (which misfiled e.g. "wire output_data;").
            kind = def_line[0][0]
            if kind == 'input':
                in_set.add(signal_name)
            elif kind == 'output':
                out_set.add(signal_name)
            continue

        comb = comb_re.findall(line)
        if comb:
            coi_dict[idx] = line
            coi_signal |= get_coi_signal(comb[0][-1])

    un_add_set = un_add_set.union(coi_signal)





def _load_design(bench, design_name):
    """Read one design's flattened netlist lines and its endpoint-register list.

    Returns:
        ``(lines, reg_lst)``, where ``lines`` are the lines of
        ``<LS_dir>/<bench>/coi/<design_name>_flatten.v`` and ``reg_lst`` is
        the JSON-decoded contents of ``.../reg_lst/<design_name>.json``.
    """
    design_path = os.path.join(LS_dir, bench, "coi", f"{design_name}_flatten.v")
    with open(design_path, 'r') as f:
        lines = f.readlines()

    reg_lst_path = f"/home/coguest5/CircuitFusion/data_collectvlg/data/reg_lst/{design_name}.json"
    with open(reg_lst_path, 'r') as f:
        reg_lst = json.load(f)
    return lines, reg_lst


def run_one_design(bench, design_name):
    """Extract the COI of every endpoint register of one design, sequentially.

    Prints each endpoint before processing it; useful for debugging.
    """
    print(bench, design_name)
    lines, reg_lst = _load_design(bench, design_name)
    # run_one_ep only reads ``lines``, so all endpoints share one list. The
    # previous per-endpoint deepcopy held E full copies of the netlist at once.
    for ep in reg_lst:
        print(ep)
        run_one_ep((ep, lines, design_name))


def run_one_design_parallel(bench, design_name, processes=20):
    """Extract the COI of every endpoint register of one design on a process pool.

    Args:
        bench: Benchmark suite name (sub-directory of ``LS_dir``).
        design_name: Design to process.
        processes: Number of worker processes (default 20, as before).
    """
    print(bench, design_name)
    lines, reg_lst = _load_design(bench, design_name)
    # Workers receive pickled copies anyway, so sharing one list avoids
    # building E deep copies of the netlist in the parent.
    tasks = [(ep, lines, design_name) for ep in reg_lst]

    with Pool(processes) as pool:
        pool.map(run_one_ep, tasks)
        # Let workers exit normally (flushing their output) before the
        # context manager's terminate() runs.
        pool.close()
        pool.join()


def run_one_ep(para_lst):
    """Extract the cone of influence of one endpoint register and write it out.

    ``para_lst`` is an ``(ep, lines, design_name)`` tuple, passed as a single
    argument so the function can be used directly with ``Pool.map``:

    * ``ep`` is the endpoint register name;
    * ``lines`` are the flattened Verilog source lines (only read, never modified);
    * ``design_name`` selects the output file ``../data/ep_vlg/<design_name>/<ep>.v``.

    The function resets the module-level globals it shares with
    ``extra_one_signal`` (``add_set``, ``un_add_set``, ``coi_dict``,
    ``ep_dict``, ``in_set``, ``out_set``). It then grows the cone until no new
    fan-in signal appears and writes a ``module coi`` with ``clk``/``rst``
    plus the collected ports.

    Any failure (e.g. the endpoint's ``reg`` declaration or assignment was not
    found) is reported and swallowed, so one bad endpoint does not stop the
    batch. No partial file is written in that case.
    """
    ep_ori, lines, design_name = para_lst[0], para_lst[1], para_lst[2]

    global add_set, un_add_set, coi_dict, ep_dict, in_set, out_set
    # Signals already traced; operator tokens are pre-seeded so they are never traced.
    add_set = {ep_ori, ':', '?', '{', '}', '~', '|'}
    un_add_set = set()
    coi_dict = {}
    ep_dict = {}
    in_set, out_set = set(), set()

    out_dir = f"../data/ep_vlg/{design_name}"
    try:
        extra_one_signal(ep_ori, lines, not_ep=False)

        # Fixed point: keep tracing newly discovered fan-in signals.
        while True:
            todo = un_add_set - add_set
            if not todo:
                break
            for sig in todo:
                extra_one_signal(sig, lines, not_ep=True)
                add_set.add(sig)

        ports = ", ".join(["clk", "rst", *in_set, *out_set])
        # Build the whole module first so a missing ep_dict entry cannot leave
        # a truncated file on disk.
        content = "".join([
            f"module coi ({ports});\n",
            "  input clk;\n",
            "  input rst;\n",
            ep_dict[0],
            *(coi_dict[i] for i in sorted(coi_dict)),
            ep_dict[-1],
            "\nendmodule",
        ])

        # exist_ok: parallel workers race to create the same directory.
        os.makedirs(out_dir, exist_ok=True)
        with open(f"{out_dir}/{ep_ori}.v", 'w') as f:
            f.write(content)

        print(design_name + "  " + ep_ori + "  Finish!")

    except Exception as err:
        print(f"{ep_ori}  Error! ({type(err).__name__}: {err})")




if __name__ == '__main__':
    # Benchmark manifest: keyed by suite name; each entry is a dict keyed by
    # design name.
    design_json = "/home/coguest5/rtl_repr/LS-benchmark/design.json"
    # Available suites: 'itc', 'opencores', 'VexRiscv', 'chipyard',
    # 'riscvcores', 'NVDLA'. Only the suites listed here are processed.
    bench_list = ['NVDLA']
    # Set to a single design name to process only that design; "" means all.
    design_name = ""

    # Parse the manifest once rather than once per suite.
    with open(design_json, 'r') as f:
        design_data = json.load(f)

    for bench in bench_list:
        for k in design_data[bench]:
            if design_name and k != design_name:
                continue
            # Swap in run_one_design(bench, k) for a sequential, easier-to-debug run.
            run_one_design_parallel(bench, k)
    
    


    