import json, os, re, copy
from multiprocessing import Pool
from collections import defaultdict
# Paths used by this script. They are machine-specific; the relative ones
# resolve against the current working directory.
# NOTE(review): `design_json` is rebound to an absolute path inside the
# `__main__` block, so the value below only applies when this file is imported.
design_hier_json = "./design_hier.json"
design_json = "../../design.json"
# Benchmark root; per-benchmark flattened netlists appear to live under
# <LS_dir>/<bench>/coi/<design>_flatten.v (the same root as the hard-coded
# netlist paths used when loading designs).
LS_dir = "/home/coguest5/LS-benchmark/"


# Escaped Verilog identifiers (e.g. "\u_core.state") run from the backslash to
# the next blank character.
_ESCAPED_ID_RE = re.compile(r"\\\S+")
# Sized/based literals (8'hff, 'b1, 4'sd3) and SystemVerilog fill literals
# ('0, '1, 'x, 'z); these never name a signal.
_VLOG_LITERAL_RE = re.compile(r"\d*'[sS]?[bBoOdDhH][0-9a-fA-FxXzZ_?]+|'[01xXzZ]")
# Plain (optionally dot-separated) identifiers. The look-behind rejects matches
# that start inside another token, e.g. "signed" inside "$signed".
_PLAIN_ID_RE = re.compile(r"(?<![\w$.'])[A-Za-z_][\w$.]*")


def get_coi_signal(line):
    """Return the set of signal names referenced by a Verilog expression.

    Operators, parentheses, braces, commas, numeric and based literals are
    ignored, and bit/part selects are removed, so ``"~a & mem[addr[3:0]]"``
    yields ``{"a", "mem", "addr"}``.  Escaped identifiers keep their leading
    backslash; any bracketed part inside them is removed the same way
    ``extract_signal`` does it, so the names match the keys ``parse_vlg``
    stores in ``def_dict``.

    Args:
        line: right-hand-side expression text (no trailing ``;`` needed).

    Returns:
        A set of signal-name strings (empty for a constant expression).
    """
    signals = set()

    # Escaped identifiers first: their bodies may contain characters
    # ('.', '[', ''') that would confuse the plain-identifier scan below.
    for token in _ESCAPED_ID_RE.findall(line):
        token = token.rstrip(",")
        token = re.sub(r"\[(.*)\]", "", token)
        if re.search(r"[A-Za-z0-9_.]", token):
            signals.add(token)

    remainder = _ESCAPED_ID_RE.sub(" ", line)
    # Drop literals so e.g. the "hff" of 8'hff is never taken for a name.
    remainder = _VLOG_LITERAL_RE.sub(" ", remainder)
    signals.update(_PLAIN_ID_RE.findall(remainder))
    return signals

def _load_design(bench, design_name):
    """Parse one flattened design and build one run_one_ep task per register.

    Prints ``bench design_name``, reads
    ``<LS_dir>/<bench>/coi/<design_name>_flatten.v`` and the register list
    ``.../reg_lst/<design_name>.json``, and returns a list of
    ``(ep, lines, design_name, def_dict, comb_assign_dict)`` tuples.
    """
    print(bench, design_name)
    design_path = os.path.join(LS_dir, bench, "coi", f"{design_name}_flatten.v")
    with open(design_path, 'r') as f:
        lines = f.readlines()

    def_dict, comb_assign_dict = parse_vlg(lines)

    reg_lst_path = f"/home/coguest5/CircuitFusion/data_collectvlg/data/reg_lst/{design_name}.json"
    with open(reg_lst_path, 'r') as f:
        reg_lst = json.load(f)

    return [(ep, lines, design_name, def_dict, comb_assign_dict) for ep in reg_lst]


def run_one_design(bench, design_name):
    """Extract the COI module of every register of one design, serially."""
    for para in _load_design(bench, design_name):
        run_one_ep(para)


def run_one_design_parallel(bench, design_name, processes=50):
    """Extract the COI module of every register of one design in a process pool.

    Args:
        bench: benchmark directory name under ``LS_dir``.
        design_name: design to process.
        processes: upper bound on worker processes (``None`` = CPU count).
            The pool is never larger than the number of registers.
    """
    para_lst = _load_design(bench, design_name)
    if not para_lst:
        return  # nothing to do; also avoids Pool(0), which raises ValueError

    workers = min(processes or os.cpu_count() or 1, len(para_lst))
    with Pool(workers) as p:
        p.map(run_one_ep, para_lst)
        # The context manager's exit calls terminate(); shut down cleanly first.
        p.close()
        p.join()

def extract_signal(signal_name):
    """Normalise a declaration or assignment-target fragment to a bare name.

    Every space is removed first. Then everything from the first ``[``
    through the last ``]`` is dropped, so ``" [3:0] x"`` becomes ``"x"`` and
    ``"mem [0:15]"`` becomes ``"mem"``.
    """
    without_spaces = signal_name.replace(" ", "")
    return re.sub(r"\[(.*)\]", "", without_spaces)

def extract_comb_assign(LHS):
    """Return the signal names written by a continuous-assignment target.

    A concatenation target ``{a, b[3:0]}`` yields one normalised name per
    comma-separated element; any other target yields a single-element list.
    Names are normalised with ``extract_signal``.
    """
    concat = re.search(r"{(.*)}", LHS)
    if concat is None:
        return [extract_signal(LHS)]
    return [extract_signal(part) for part in concat.group(1).split(",")]
    

# Declaration line: "  <kind> [range] name [range] [= init];" (two-space indent,
# as in the flattened netlists this script reads).
_DECL_LINE_RE = re.compile(r"^  (reg|wire|input|output|inout)(\s*)(\[(.*)\])*(\s*)(.*)(\s*);")
# Continuous assignment: "  assign lhs = rhs;" or "  wire [range] lhs = rhs;".
_COMB_LINE_RE = re.compile(r"^  (wire|assign)(.*)(\[(.*)\])*(\s+)=(\s+)(.*);")
# Leading "signed" keyword of a declaration (e.g. "wire signed [7:0] x").
_SIGNED_KW_RE = re.compile(r"^\s*signed\b")
# One bracketed range/select; non-greedy so "[7:0] mem [0:15]" keeps "mem".
_BRACKET_RE = re.compile(r"\[[^\]]*\]")


def _decl_name(decl_body):
    """Return the identifier declared by the text between the kind keyword and ';'."""
    name = decl_body.split("=", 1)[0]   # drop an initializer (it may itself contain '=')
    name = _SIGNED_KW_RE.sub("", name)  # "signed [7:0] x" -> " [7:0] x"
    name = _BRACKET_RE.sub("", name)    # drop packed and unpacked (memory) ranges
    return "".join(name.split())


def parse_vlg(lines):
    """Index declarations and continuous assignments of a flattened netlist.

    Args:
        lines: the netlist as returned by ``readlines()``.

    Returns:
        ``(def_dict, comb_assign_dict)`` where

        * ``def_dict`` maps a signal name to the index of its declaration
          line (the last declaration wins);
        * ``comb_assign_dict`` (a ``defaultdict(list)``) maps every signal
          written by an ``assign``/initialised ``wire`` line to a list of
          ``(line_index, rhs_signal_set)`` pairs.
    """
    def_dict = {}
    comb_assign_dict = defaultdict(list)

    for idx, line in enumerate(lines):
        # 1. signal definition
        decl = _DECL_LINE_RE.match(line)
        if decl:
            # Text between the kind keyword and the terminating ';'.
            def_dict[_decl_name(line[decl.end(1):decl.end() - 1])] = idx

        # 2. combinational (continuous) assignment
        comb = _COMB_LINE_RE.match(line)
        if comb:
            lhs = _SIGNED_KW_RE.sub("", comb.group(2))
            rhs_set = get_coi_signal(comb.group(7))
            for lhs_name in extract_comb_assign(lhs):
                comb_assign_dict[lhs_name].append((idx, rhs_set))

    return def_dict, comb_assign_dict

# Names the generated module header always declares as inputs; a cone signal
# with one of these names must not be declared (or listed as a port) twice.
_HEADER_INPUTS = ("clk", "rst")
# Continuous assignments cannot hold non-blocking assignments, so a "<=" on
# such a line is a less-or-equal comparison, never a register update.
_CONT_ASSIGN_RE = re.compile(r"^\s*(assign|wire)\b")
# Rejects endpoint matches that are really the tail of a longer identifier,
# e.g. "cnt" inside "bit_cnt" or inside an escaped "\u_core.cnt".
_EP_PREFIX_GUARD = r"(?<![\w$.\\])"


def _find_endpoint(ep, lines):
    """Scan *lines* for the declaration and non-blocking assignment of *ep*.

    Returns ``(decl_line, always_block, rhs)``: the ``output reg`` declaration
    to emit, the rebuilt ``always`` block, and the raw RHS expression.  Each
    is ``None`` when no matching line exists; if several lines match, the
    last one wins.
    """
    name = re.escape(ep)  # endpoint names may contain '.', '\\', '[' ...
    decl_re = re.compile(r"^  reg(.*)(\s+)" + name + r"(\s*);")
    # "ep <= expr[sel];" -- NOTE(review): the trailing select is dropped from
    # the RHS (pre-existing behaviour); confirm this is intended.
    seq_sel_re = re.compile(
        _EP_PREFIX_GUARD + name + r"(\s*)(\[(.*)\])*(\s*)<=(\s*)(.*)(\s*)(\[(.*)\]);")
    seq_re = re.compile(_EP_PREFIX_GUARD + name + r"(\s*)(\[(.*)\])*(\s*)<=(\s*)(.*);")

    decl_line = always_block = rhs = None
    for line in lines:
        decl = decl_re.search(line)
        if decl:
            decl_line = f'  output reg {decl.group(1)} {ep};\n'
        if _CONT_ASSIGN_RE.match(line):
            continue
        seq = seq_sel_re.search(line) or seq_re.search(line)
        if seq:
            rhs = seq.group(6)
            always_block = f"  always @(posedge clk) begin\n    {ep} <= {rhs};\n  end"
    return decl_line, always_block, rhs


def _collect_cone(ep, fanins, lines, def_dict, comb_assign_dict):
    """Collect the combinational fan-in cone that drives register *ep*.

    Starting from *fanins* (the signals in ep's RHS), walk backwards through
    continuous assignments until registers / primary inputs are reached.
    A plain worklist with a visited set is used, so the walk always runs to
    completion and terminates on cycles.

    Returns ``(coi_lines, in_ports, out_ports)`` where ``coi_lines`` maps a
    netlist line index to the text to emit for it.  Registers are
    re-declared as ``input`` because they cut the cone.  ``ep`` itself is
    never visited; the caller declares it as the module's output register.
    Signals without a declaration are skipped.
    """
    coi_lines = {}
    in_ports, out_ports = set(), set()
    visited = {ep}
    worklist = list(fanins)
    while worklist:
        sig = worklist.pop()
        if sig in visited:
            continue
        visited.add(sig)

        def_idx = def_dict.get(sig)
        if def_idx is None:
            continue
        def_line = lines[def_idx]
        if def_line.startswith("  reg"):
            def_line = "  input" + def_line[len("  reg"):]
        if def_line.startswith("  input"):
            # clk/rst are already declared by the module header.
            if sig not in _HEADER_INPUTS:
                in_ports.add(sig)
                coi_lines[def_idx] = def_line
        else:
            if def_line.startswith("  output"):
                out_ports.add(sig)
            coi_lines[def_idx] = def_line

        # .get() rather than [] so a defaultdict argument is not mutated.
        for comb_idx, comb_fanins in comb_assign_dict.get(sig, ()):
            coi_lines[comb_idx] = lines[comb_idx]
            worklist.extend(comb_fanins)
    return coi_lines, in_ports, out_ports


def _write_coi_module(design_name, ep, decl_line, always_block, coi_lines, in_ports, out_ports):
    """Write ``../data/ep_vlg/<design_name>/<ep>.v`` holding module ``coi``."""
    out_dir = f"../data/ep_vlg/{design_name}"
    # exist_ok: parallel workers for the same design race to create this dir.
    os.makedirs(out_dir, exist_ok=True)
    # Sorted so the port order is reproducible across runs (set order is not).
    ports = ["clk", "rst", *sorted(in_ports), *sorted(out_ports)]
    with open(f"{out_dir}/{ep}.v", 'w') as f:
        f.write(f"module coi ({', '.join(ports)});\n")
        f.write("  input clk;\n")
        f.write("  input rst;\n")
        f.write(decl_line)
        for idx in sorted(coi_lines):
            f.write(coi_lines[idx])
        f.write(always_block)
        f.write('\nendmodule')


def run_one_ep(para_lst):
    """Extract the cone of influence of one register into its own module.

    Args:
        para_lst: tuple ``(ep, lines, design_name, def_dict, comb_assign_dict)``
            as built by ``run_one_design*``.

    On success writes ``../data/ep_vlg/<design_name>/<ep>.v`` and prints
    ``"<design>  <ep>  Finish!"``.  An endpoint has to meet three conditions:
    it must have a ``reg`` declaration, it must have a non-blocking
    assignment, and every signal in its RHS must be declared.  An endpoint
    that fails any of these is rejected with ``"<ep>  Error!"`` and nothing
    is written.  Any other exception is reported rather than propagated, so
    one bad endpoint cannot abort a whole ``Pool.map``.
    """
    ep, lines, design_name, def_dict, comb_assign_dict = para_lst[:5]
    try:
        decl_line, always_block, rhs = _find_endpoint(ep, lines)
        if decl_line is None or rhs is None:
            print(ep + "  Error!")
            return

        fanins = get_coi_signal(rhs)
        if any(sig not in def_dict for sig in fanins):
            print(ep + "  Error!")
            return

        coi_lines, in_ports, out_ports = _collect_cone(
            ep, fanins, lines, def_dict, comb_assign_dict)
        out_ports.add(ep)
        _write_coi_module(design_name, ep, decl_line, always_block,
                          coi_lines, in_ports, out_ports)
        print(design_name + "  " + ep + "  Finish!")
    except Exception as err:
        print(f"{ep}  Error! ({type(err).__name__}: {err})")






if __name__ == '__main__':
    design_json = "/home/coguest5/rtl_repr/LS-benchmark/design.json"
    # Benchmarks to process, in order. Narrow this list (e.g. ['itc'],
    # ['opencores'], ['riscvcores'], ['VexRiscv'], ['chipyard'], ['NVDLA'])
    # to rerun a subset.
    bench_list = ['itc', 'opencores', 'VexRiscv', 'chipyard', 'riscvcores', 'NVDLA']
    # Set to a single design name (e.g. "spi") to process only that design;
    # leave empty to process every listed design.
    design_name = ""

    with open("/home/coguest5/CircuitFusion/data_collect/dataset/json/design_lst/design_lst.json", 'r') as f:
        design_lst = json.load(f)

    # design.json does not change between benchmarks: parse it once.
    with open(design_json, 'r') as f:
        design_data = json.load(f)

    for bench in bench_list:
        for k in design_data[bench]:
            if k not in design_lst:
                continue
            if design_name and k != design_name:
                continue
            # Swap in run_one_design(bench, k) for serial debugging.
            run_one_design_parallel(bench, k)
    
    


    