import json
import os
import re
from multiprocessing import Pool

def merge_lines_para(para_lst):
    """Clean one netlist; single-argument form so it can be used with ``Pool.map``.

    ``para_lst`` is indexed positionally: element 0 is the path of the raw
    netlist to read, element 1 the path the cleaned netlist is written to.
    Unlike ``merge_lines`` no per-pass line counts are printed; a single
    ``Finish <output>`` message is printed when the file has been written.
    """
    src_path = para_lst[0]
    dst_path = para_lst[1]
    with open(src_path, 'r') as handle:
        content = handle.readlines()

    content = merge_lines_first_time(content, dst_path)

    # Keep running merge passes until one no longer shortens the list.
    previous_count = None
    while previous_count != len(content):
        previous_count = len(content)
        content = merge_lines_once(content, dst_path)

    with open(dst_path, 'w') as handle:
        handle.writelines(content)
    print(f"Finish {dst_path}")

def merge_lines(input_file, output_file):
    """Read ``input_file``, merge continuation lines, write ``output_file``.

    After the initial pass, merge passes are repeated until the number of
    lines stops changing.  Each pass prints the line count before and after
    it, as a progress trace.
    """
    with open(input_file, 'r') as src:
        buffer = src.readlines()

    buffer = merge_lines_first_time(buffer, output_file)

    converged = False
    while not converged:
        before = len(buffer)
        buffer = merge_lines_once(buffer, output_file)
        after = len(buffer)
        print(before, after)
        converged = before == after

    with open(output_file, 'w') as dst:
        dst.writelines(buffer)

# Line classifiers shared by the merge passes.  Every pattern is anchored at
# the start of the line, so ``match`` is equivalent to the truthiness of
# ``re.findall`` on the same pattern.
_COMMENT_RE = re.compile(r"^//")
_MODULE_RE = re.compile(r"(^module|^endmodule)(.*)")
# A "statement" line starts with exactly two spaces followed by
# input/output/wire or a token containing ``_X<digits>`` (presumably a
# standard-cell instance such as ``AND2_X1`` -- assumption, not verified).
_STATEMENT_RE = re.compile(r"^  (input|output|wire|(\S+_X(\d+)))(.*)")


def _merge_continuations(lines):
    """Return a new list with comments dropped and continuations folded in.

    A line is a *continuation* when it is not a ``//`` comment, does not
    start with ``module``/``endmodule``, and does not match the two-space
    statement pattern.  Each continuation is appended to the nearest
    preceding kept line as ``prev.strip() + ' ' + cur.strip() + '\\n'``;
    the result is re-indented with two spaces unless ``prev`` is a
    module/endmodule line.  Folding left to right in one pass yields the
    same text the old "merge one line per pass until stable" loop produced,
    in O(n) instead of one full pass (plus file round trip) per merged line.
    """
    merged = []
    for line in lines:
        if _COMMENT_RE.match(line):
            continue
        if not merged or _MODULE_RE.match(line) or _STATEMENT_RE.match(line):
            # Kept verbatim.  A continuation with nothing before it is also
            # kept in place instead of wrapping around onto the last line of
            # the file (the old ``lines_fixed[idx - 1]`` with ``idx == 0``).
            merged.append(line)
            continue
        prev = merged[-1]
        joined = prev.strip() + ' ' + line.strip() + '\n'
        # Module headers stay flush left; everything else gets back the
        # two-space indent that strip() removed.
        merged[-1] = joined if _MODULE_RE.match(prev) else '  ' + joined
    return merged


def _write_lines(lines, output_file):
    """Write ``lines`` to ``output_file`` (intermediate checkpoint, overwritten by the caller)."""
    with open(output_file, 'w') as f:
        f.writelines(lines)


def merge_lines_first_time(lines, output_file):
    """Drop ``//`` comments and merge all continuation lines in ``lines``.

    Args:
        lines: netlist text as returned by ``readlines()``; not mutated.
        output_file: path the merged result is also written to.

    Returns:
        A new list of lines; merged lines always end with ``'\\n'``,
        untouched lines are returned byte-for-byte.
    """
    merged = _merge_continuations(lines)
    _write_lines(merged, output_file)
    return merged


def merge_lines_once(lines, output_file):
    """Run one merge pass over ``lines`` and write the result to ``output_file``.

    Because a single pass now merges every continuation, this is idempotent
    on the output of ``merge_lines_first_time``: it returns an equal list of
    the same length, so callers' "repeat until the length stops changing"
    loops terminate after one call.  On partially merged input (e.g. from the
    old one-line-per-pass scheme) it completes the merge.
    """
    merged = _merge_continuations(lines)
    _write_lines(merged, output_file)
    return merged
    


    


# input_file = "/home/coguest5/rtl_repr/data_collect/dataset_net/data/netlist/b14_b14_TYP.syn.v"
# output_file = "./b14_b14_TYP.syn.v"
# merge_lines(input_file, output_file)



def run_all(bench, design_name=None, *,
            design_json="/home/coguest5/AST_analyzer/LS-benchmark/design.json",
            design_lst_json="/home/coguest5/rtl_repr/data_collect/dataset/json/design_lst/design_lst.json",
            netlist_dir="../data/netlist",
            clean_dir="../data/netlist_clean"):
    """Sequentially clean the netlists of one benchmark suite.

    Args:
        bench: key into the top-level object of ``design_json``.
        design_name: if truthy, only this design is processed; ``None`` or
            ``''`` processes every selected design.
        design_json: JSON mapping ``bench -> {design: [top, ...]}``; element
            0 of each value is used as the netlist file-name prefix.
        design_lst_json: JSON list of design names allowed to be processed.
        netlist_dir: directory holding ``{top}_{design}_TYP.syn.v`` inputs.
        clean_dir: output directory for ``{design}.syn.v``; created if missing.

    Raises:
        KeyError: if ``bench`` is not present in ``design_json``.
    """
    with open(design_lst_json, 'r') as f:
        design_set = set(json.load(f))
    with open(design_json, 'r') as f:
        bench_data = json.load(f)[bench]

    # Without this, the first open(output_file, 'w') fails on a fresh checkout.
    os.makedirs(clean_dir, exist_ok=True)

    for k, v in bench_data.items():
        if k not in design_set:
            continue
        if design_name and k != design_name:
            continue
        design_top = v[0]
        input_file = f"{netlist_dir}/{design_top}_{k}_TYP.syn.v"
        output_file = f"{clean_dir}/{k}.syn.v"
        merge_lines(input_file, output_file)

def run_all_parallel(bench, *, processes=50,
                     design_json="/home/coguest5/AST_analyzer/LS-benchmark/design.json",
                     design_lst_json="/home/coguest5/rtl_repr/data_collect/dataset/json/design_lst/design_lst.json",
                     netlist_dir="../data/netlist",
                     clean_dir="../data/netlist_clean"):
    """Clean every selected netlist of ``bench`` using a process pool.

    Args:
        bench: key into the top-level object of ``design_json``.
        processes: maximum number of worker processes; the pool is never
            larger than the number of jobs.
        design_json: JSON mapping ``bench -> {design: [top, ...]}``; element
            0 of each value is used as the netlist file-name prefix.
        design_lst_json: JSON list of design names allowed to be processed.
        netlist_dir: directory holding ``{top}_{design}_TYP.syn.v`` inputs.
        clean_dir: output directory for ``{design}.syn.v``; created if missing.

    The job list is printed before any work starts.  If no design is
    selected the function returns without starting a pool.
    """
    with open(design_lst_json, 'r') as f:
        design_set = set(json.load(f))
    with open(design_json, 'r') as f:
        bench_data = json.load(f)[bench]

    para_lst = [
        (f"{netlist_dir}/{v[0]}_{k}_TYP.syn.v", f"{clean_dir}/{k}.syn.v")
        for k, v in bench_data.items()
        if k in design_set
    ]
    print(para_lst)

    if not para_lst:
        return

    # Workers write into clean_dir; create it once here in the parent.
    os.makedirs(clean_dir, exist_ok=True)

    with Pool(min(processes, len(para_lst))) as p:
        p.map(merge_lines_para, para_lst)
        # map() has already blocked until all jobs finished; close/join lets
        # workers exit cleanly before the context manager calls terminate().
        p.close()
        p.join()

if __name__ == '__main__':
    # Every benchmark suite this script knows about.  Any of them can be
    # cleaned sequentially with run_all(bench, '') -- or a single design with
    # e.g. run_all("chipyard", "TinyRocket") -- or in parallel with
    # run_all_parallel(bench).
    bench_list_all = ['itc', 'opencores', 'VexRiscv', 'chipyard', 'riscvcores', 'NVDLA']

    # Suites processed on this invocation.
    benches_to_run = ['NVDLA']
    for bench in benches_to_run:
        run_all_parallel(bench)
