import os
from abc import abstractmethod
from copy import deepcopy
import re
import numpy as np


class EntryFormat:
    def __init__(self, ruleList: list):
        self.ruleList = ruleList

    def __call__(self, entry: str):
        for rule in self.ruleList:
            match, output = rule(entry)
            if match:
                return {} if output == None else output
        return {}


class EntryRule:
    @abstractmethod
    def __call__(self, entry: str) -> (bool, dict):
        pass


class AlwaysMatchedRule(EntryRule):
    def __call__(self, entry: str) -> (bool, dict):
        return (True, {})

class SkipRule(AlwaysMatchedRule):
    def __call__(self, entry: str) -> (bool, dict):
        return (
            True,
            {
                "entry" : entry,
            }
        )

class AlwaysDismatchedRule(EntryRule):
    def __call__(self, entry: str) -> (bool, dict):
        return (False, {})


class PrefixRule(EntryRule):
    def __init__(self, prefix: str):
        self.prefix = prefix

    def __call__(self, entry: str) -> (bool, dict):
        return (
            entry.strip().startswith(self.prefix),
            {}
        )


PrefixIgnoreRule = PrefixRule # alias

class PrefixSkipRule(PrefixRule):
    def __call__(self, entry: str) -> (bool, dict):
        return (
            entry.strip().startswith(self.prefix),
            {
                "entry": entry
            }
        )


class RegExRule(EntryRule):
    def __init__(self, regex: str, index_group_pairs: dict):
        self.regex = re.compile(regex)
        self.index_group_pairs = index_group_pairs
    
    def __call__(self, entry: str) -> (bool, dict):
        match = self.regex.search(entry)
        if match == None:
            return (False, {})
        output = {}
        for index, group in self.index_group_pairs.items():
            if group == -1:
                output[index] = entry
            else:
                output[index] = match.group(group)
        return (True, output)


class RuleGroup:
    def __init__(
        self,
        entrance_rule: EntryRule,
        exit_rule: EntryRule,
        ruleList: list,
        dismatch_policy = "exit_rule"
    ):
        self.entrance_rule = entrance_rule
        self.exit_rule = exit_rule
        self.ruleList = ruleList
        self.dismatch_policy = dismatch_policy

    def access(self, entry: str) -> bool:
        return self._enter(entry)

    def _enter(self, entry: str) -> bool:
        return self.entrance_rule(entry)[0]

    def _exit(self, entry: str) -> bool:
        return self.exit_rule(entry)[0]

    def _dismatch(self, entry: str) -> bool:
        def dismatch_policy_exit_rule(entry: str):
            return self._exit(entry)

        def dismatch_policy_exit(entry: str):
            return True 

        def dismatch_policy_stay(entry: str):
            return False

        dismatch_policy_set = {
            "default": dismatch_policy_exit_rule,
            "exit_rule": dismatch_policy_exit_rule,
            "exit": dismatch_policy_exit,
            "stay": dismatch_policy_stay,
        }

        dismatch_policy = dismatch_policy_set.get(
            self.dismatch_policy,
            dismatch_policy_set["default"]
        )

        return dismatch_policy(entry)

    def __call__(self, entry: str):
        for rule in self.ruleList:
            match, output = rule(entry)
            if match:
                return (
                    {} if output == None else output,
                    self._exit(entry)
                )
        
        return (
            {},
            self._dismatch(entry)
        )
        

class EntryFormatWithRuleGroups(EntryFormat):
    def __init__(self, ruleGroups: list):
        self.ruleGroups = ruleGroups
        self._nowGroup = None

    def inGroup(self) -> bool:
        return self._nowGroup != None

    def quitGroup(self):
        self._nowGroup = None

    def __call__(self, entry: str):
        if self._nowGroup == None:
            for ruleGroup in self.ruleGroups:
                if ruleGroup.access(entry):
                    self._nowGroup = ruleGroup
                    break
            else:
                # no available rule group
                return {}
        
        output, exited = self._nowGroup(entry)
        if exited:
            self.quitGroup()
        return output
    
def read_benchmark(database, benchmark, args=None):
    print("read database from benchmark %s" % (benchmark))
    database["benchmark_dir"] = os.path.join(os.path.dirname(os.path.abspath('.')), f'benchmark/{benchmark}')
    return read_benchmark_from_def(database, os.path.join(database["benchmark_dir"], f'{benchmark}.def'), args)

def read_benchmark_from_def(database, benchmark, args=None):
    node_info = {}
    node_info_raw_id_name = {}

    node_cnt = 0
    port_info = {}
    standard_cell_name = []
    port_to_net_dict = {}


    database["def_file"] = os.path.basename(benchmark)

    design_name = os.path.splitext(database["def_file"])[0]

    database["files"] = {
        "def_file": database["def_file"],
        "lef_file": "%s.lef" % design_name,
        "v_file": "%s.v" % design_name
    }

    database["nodes"] = {}
    database["macros"] = []


    read_lef(
        database,
        os.path.join(
            database["benchmark_dir"],
            database["files"]["lef_file"]
        )
    )
    
    read_def(
        database,
        os.path.join(
            database["benchmark_dir"],
            database["files"]["def_file"]
        )
    )

    # compute area each macro type
    macro_type_area = {}
    for macro_type in database["macro_size"]:
        size_x, size_y = database["macro_size"][macro_type]
        area = size_x * size_y
        macro_type_area[macro_type] = area

    # compute area each cell
    cell_area_dict = {}
    cell_total_area = 0
    for cell in database["nodes"]:
        cell_type = database["nodes"][cell]["node_type"]
        area = macro_type_area[cell_type]
        cell_area_dict[cell] = area
        cell_total_area += area
    
    cell_lst = list(cell_area_dict.keys())
    cell_lst = sorted(cell_lst, key=lambda x: cell_area_dict[x], reverse=True)

    if args is None:
        n_macro = 512
    else:
        n_macro = args.n_macro
    
    macro_lst = cell_lst[:n_macro]
    standard_cell_name = cell_lst[n_macro:]


    max_height = 0
    max_width = 0
    ratio_x, ratio_y = get_scaling_ratio(database)
    for id, macro in enumerate(macro_lst):
        # scaling
        place_x = eval(database["nodes"][macro]['x']) * ratio_x
        place_y = eval(database["nodes"][macro]['y']) * ratio_y

        macro_type = database["nodes"][macro]["node_type"]
        size_x, size_y = database["macro_size"][macro_type]

        max_height = max(max_height, size_y + place_y)
        max_width = max(max_width, size_x + place_x)
        node_info[macro] = {"id": id, "x": size_x, "y": size_y}
        node_info[macro]["raw_x"] = place_x
        node_info[macro]["raw_y"] = place_y
        node_info_raw_id_name[id] = macro


    node_cnt = len(node_info)
    assert node_cnt == n_macro

    v_file = open(os.path.join(database["benchmark_dir"],database["files"]["v_file"]), 'r')
    
    net_info = read_v(v_file, node_info, database)
    net_cnt = len(net_info)

    placedb_info = {
        'node_info' : node_info,
        'node_info_raw_id_name' : node_info_raw_id_name,
        'node_cnt' : node_cnt,
        'port_info' : port_info,
        'net_info' : net_info,
        'net_cnt' : net_cnt,
        'max_height' : max_height,
        'max_width' : max_width,
        'standard_cell_name' : standard_cell_name,
        'port_to_net_dict' : port_to_net_dict,
        'cell_total_area' : cell_total_area,
    }
    return placedb_info
    

def read_v(fopen, node_info, database):
    net_info = {}
    net_cnt = 0
    flag = 0
    for line in fopen.readlines():
        if 'wire' in line:
            line_ls = line.split(" ")
            net_name = line_ls[1].split(";")[0]
            net_info[net_name] = {}
            net_info[net_name]["nodes"] = {}
            net_info[net_name]["ports"] = {}

        if 'Start cells' in line:
            flag = 1
            continue
        if flag == 1:
            if line == '\n':
                break
            pattern = r"\.(\w+)\((.*?)\)"
            matches = re.findall(pattern=pattern, string=line)
            line_ls = line.split(" ")
            node_type = line_ls[0]
            node_name = line_ls[1]

            for pin_net in matches:
                pin = pin_net[0]
                net = pin_net[1]
                if node_name in node_info.keys() and net in net_info.keys():
                    x_offset, y_offset = database["pin_offset"][node_type][pin]
                    net_info[net]["nodes"][node_name] = {
                        "x_offset": x_offset,
                        "y_offset": y_offset,
                    }


    for net_name in list(net_info.keys()):
        if len(net_info[net_name]["nodes"]) <= 1:
            net_info.pop(net_name)
    for net_name in net_info:
        net_info[net_name]['id'] = net_cnt
        net_cnt += 1
    return net_info

def read_lef(database, lef_file):
    macro_start_rule = RegExRule(
        r"MACRO\s+(\w+)",
        {
            "macro_name": 1,
        }
    )
    macro_or_pin_end_rule = RegExRule(
        r"END\s+(\w+)",
        {
            "macro_name_or_pin_name": 1,
        }
    )
    macro_size_rule = RegExRule(
        r"SIZE\s(\d+(\.\d+)?) BY (\d+(\.\d+)?)",
        {
            "size_x" : 1,
            "size_y" : 3,
        }
    )
    macro_pin_start_rule = RegExRule(
        r"(PIN)\s+(\w+)",
        {
            "pin_start" : 1, 
            "pin_name" : 2,
        }
    )
    macro_pin_offset_rule = RegExRule(
        r"RECT\s(-?\d+(\.\d+)?)\s(-?\d+(\.\d+)?)\s(-?\d+(\.\d+)?)\s(-?\d+(\.\d+)?)",
        {
            'x1' : 1,
            'y1' : 3,
            'x2' : 5,
            'y2' : 7,
        }
    )

    macro_rule_group = RuleGroup(
        macro_start_rule,
        AlwaysDismatchedRule(),
        [
            macro_start_rule,
            macro_or_pin_end_rule,
            macro_size_rule,
            macro_pin_start_rule,
            macro_pin_offset_rule,
            RegExRule(
                r"CLASS\s+(\w+)\s+;",
                {
                    "class": 1,
                }
            ),
            SkipRule()
        ]
    )

    other_rule_group = RuleGroup(
        AlwaysMatchedRule(),
        AlwaysMatchedRule(),
        [
            SkipRule()
        ]
    )

    lef_ent_format = EntryFormatWithRuleGroups(
        [
            macro_rule_group,
            other_rule_group
        ]
    )

    assert lef_file is not None and os.path.exists(lef_file)
    database["lef_macros"] = {}
    database["lef_origin"] = [""]
    database["macro_size"] = {}
    database["pin_offset"] = {}
    macro_name = None
    pin_name = None
    pin_flag = False
    with open(lef_file, "r") as f:
        for line in f:
            output = lef_ent_format(line)
            if not lef_ent_format.inGroup():
                database["lef_origin"][-1] += line
                continue

            if macro_name is None:
                macro_name = output.get("macro_name", None)
                if macro_name is not None:
                    database["lef_macros"][macro_name] = line
                    database["lef_origin"].append("")
                    database["pin_offset"][macro_name] = {}
                continue
            
            if "macro_name_or_pin_name" in output.keys():
                if pin_flag:
                    pin_flag = False

                    min_pin_x = np.min(pin_x)
                    max_pin_x = np.max(pin_x)
                    min_pin_y = np.min(pin_y)
                    max_pin_y = np.max(pin_y)
                    pin_name = output["macro_name_or_pin_name"]
                    database["pin_offset"][macro_name][pin_name] = ((min_pin_x + max_pin_x)/2,
                                                                    (min_pin_y + max_pin_y)/2)
                    del pin_x
                    del pin_y
                    pin_flag = False
                else:
                    database["lef_macros"][macro_name] += line
                    if output["macro_name_or_pin_name"] == macro_name:
                        lef_ent_format.quitGroup()
                        macro_name = None
            elif "size_x" in output.keys():
                database["macro_size"][macro_name] = (eval(output['size_x']), eval(output['size_y']))
            elif "class" in output.keys():
                # rule: CLASS (CORE|BLOCK)
                c = output["class"]
                if c == "CORE":
                    database["lef_macros"][macro_name] += line
                else:
                    database["lef_macros"][macro_name] += \
                        line.replace(c, "CORE", 1)
            elif "pin_start" in output.keys():
                pin_flag = True
                pin_x = []
                pin_y = []
            elif "x1" in output.keys():
                if pin_flag:
                    x1, y1 = eval(output["x1"]), eval(output["y1"])
                    x2, y2 = eval(output["x2"]), eval(output["y2"])
                    pin_x.append(x1)
                    pin_x.append(x2)
                    pin_y.append(y1)
                    pin_y.append(y2)
            else:
                # rule: skip
                database["lef_macros"][macro_name] += line

def read_def(database, def_file):
    
    
    component_start_rule = RegExRule(
        r"COMPONENTS\s+(\d+)\s*;",
        {
            "num_comps": 1,
        }
    )
    component_end_rule = RegExRule(
        r"END\s+COMPONENTS",
        {}
    )
    design_rule = RegExRule(
        r"\s*DESIGN\s+([a-z_A-Z]+)\s+([a-z_A-Z]+)\s+(\d+\.?\d+)\s*;\s*\n?",
        {
            "entry" : 0,
            "key" : 1,
            "value" : 3,
        }
    )
    diearea_rule = RegExRule(
        r"DIEAREA\s+\(\s*(\d+)\s*(\d+)\s*\)\s*\(\s*(\d+)\s*(\d+)\s*\)\s*;\s*\n?",
        {
            # "entry" : 0,
            "lower_x" : 1,
            "lower_y" : 2,
            "upper_x" : 3,
            "upper_y" : 4,
        }
    )
    row_rule = RegExRule(
        r"ROW\s+coreROW_(\d+)\s+core\s+(\d+)\s+(\d+)\s+N\s+DO\s+(\d+)\s+BY\s+(\d+)\s+STEP\s+(\d+)\s+(\d+)\s+;",
        {
            "id" : 1,
            "row_x" : 2,
            "row_y" : 3,
            "after_do" : 4,
            "after_by" : 5,
            "step_x" : 6,
            "step_y" : 7,
        }
    )
    component_rule_group = RuleGroup(
        component_start_rule,
        component_end_rule,
        [
            component_start_rule,
            component_end_rule,
            RegExRule(
                r"(-)\s+(\w+)\s+(\w+)",
                {
                    "head": 1,
                    "node_name": 2,
                    "node_type": 3,
                }
            ),
            RegExRule(
                r"(\+)\s+(\w+)\s*\(\s*([+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)\s*([+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?)\s*\)\s*(\w+)\s+;",
                {
                    "head": 1,
                    "state": 2,
                    "x": 3,
                    "y": 7,
                    "dir": 11,
                }
            ),
        ]
    )

    other_rule_group = RuleGroup(
        
        AlwaysMatchedRule(),
        AlwaysMatchedRule(),
        [
            design_rule,
            diearea_rule,
            row_rule,
            SkipRule()
        ]
    )

    def_ent_format = EntryFormatWithRuleGroups(
        [
            component_rule_group,
            other_rule_group
        ]
    )

    assert def_file is not None and os.path.exists(def_file)
    database["def_origin"] = [""]
    database["design_config"] = {}
    database["diearea_rect"] = []
    database["row"] = {}
    with open(def_file, "r") as f:
        for line in f:
            output = def_ent_format(line)
            if output.get("entry", None) is not None:
                entry = output.get("entry")
                database["def_origin"][-1] += "%s" % (entry)
            else:
                database["def_origin"].append("")
            if output == {}:
                continue

            if "num_comps" in output.keys():
                database.update(output)
            elif "key" in output.keys():
                database["design_config"][output["key"]] = eval(output["value"]) 

            elif "lower_x" in output.keys():
                database["diearea_rect"].extend([eval(output['lower_x']), 
                                                eval(output['lower_y']),
                                                eval(output['upper_x']),
                                                eval(output['upper_y'])])
            elif "row_y" in output.keys():
                database["row"][output["id"]] = {}
                database["row"][output["id"]]["row_x"] = int(output["row_x"])
                database["row"][output["id"]]["row_y"] = int(output["row_y"])
                database["row"][output["id"]]["after_do"] = int(output["after_do"])
                database["row"][output["id"]]["after_by"] = int(output["after_by"])
                database["row"][output["id"]]["step_x"] = int(output["step_x"])
                database["row"][output["id"]]["step_y"] = int(output["step_y"])
            else:
                head = output.get("head", None)
                if head == '-':
                    node_name = output["node_name"]
                    database["nodes"][node_name] = deepcopy(output)
                elif head == '+':
                    state = output["state"]
                    if state == "FIXED":
                        database["macros"].append(node_name)
                    database["nodes"][node_name].update(output)
                else:
                    continue   
    
def get_scaling_ratio(database):
    design_range_x = database["design_config"]["FE_CORE_BOX_UR_X"] - database["design_config"]["FE_CORE_BOX_LL_X"]
    design_range_y = database["design_config"]["FE_CORE_BOX_UR_Y"] - database["design_config"]["FE_CORE_BOX_LL_Y"]
    diearea_range_x = database["diearea_rect"][2] - database["diearea_rect"][0]
    diearea_range_y = database["diearea_rect"][3] - database["diearea_rect"][1]

    ratio_x = design_range_x / diearea_range_x
    ratio_y = design_range_y / diearea_range_y

    return ratio_x, ratio_y

def get_inv_scaling_ratio(database):
    design_range_x = database["design_config"]["FE_CORE_BOX_UR_X"] - database["design_config"]["FE_CORE_BOX_LL_X"]
    design_range_y = database["design_config"]["FE_CORE_BOX_UR_Y"] - database["design_config"]["FE_CORE_BOX_LL_Y"]
    diearea_range_x = database["diearea_rect"][2] - database["diearea_rect"][0]
    diearea_range_y = database["diearea_rect"][3] - database["diearea_rect"][1]

    ratio_x = diearea_range_x / design_range_x
    ratio_y = diearea_range_y / design_range_y

    return ratio_x, ratio_y
        

def write_def(macro_pos, database, def_file, ratio_x, ratio_y, limited_x=None, limited_y=None, is_dataset=False):
    def_origin = list(reversed(database["def_origin"]))
    content = def_origin.pop()

    content += f"DIEAREA ( 0 0 ) ( {limited_x} {limited_y} ) ;\n"
    content += def_origin.pop()

    delta =  database["row"]['2']["row_y"] - database["row"]['1']["row_y"]
    row_id_lst = sorted(database["row"].keys(), key=lambda x:int(x))

    for row_id in row_id_lst:
        after_do = min(database["row"][row_id]["after_do"], (limited_x - database["row"][row_id]["row_x"]) // database["row"][row_id]["step_x"])
        if database["row"][row_id]["row_y"] + delta <= limited_y:
            content += f"ROW coreROW_{row_id} core {database['row'][row_id]['row_x']} " + \
                       f"{database['row'][row_id]['row_y']} N DO {after_do} " + \
                       f"BY {database['row'][row_id]['after_by']} STEP {database['row'][row_id]['step_x']} " + \
                       f"{database['row'][row_id]['step_y']} ;\n"
        content += def_origin.pop()
    
    content += "COMPONENTS %s ;\n" % (database["num_comps"])

    node_list = database["nodes"].keys()
    inv_ratio_x, inv_ratio_y = get_inv_scaling_ratio(database)
    for node_name in node_list:
        node_info = database["nodes"][node_name]
        content += \
            "- %(node_name)s %(node_type)s\n" % node_info
        if node_name in macro_pos.keys():
            x, y, _, _ = macro_pos[node_name]
            x = round(x * ratio_x + ratio_x) 
            y = round(y * ratio_y + ratio_y)

            # inv scaling
            x *= inv_ratio_x
            y *= inv_ratio_y
            
            if is_dataset:
                content += \
                    f"\t+ PLACED ( {x} {y} ) {node_info['dir']} ;\n"
            else:
                content += \
                    f"\t+ FIXED ( {x} {y} ) {node_info['dir']} ;\n"
        else:
            content += \
                "\t+ PLACED ( %(x)s %(y)s ) %(dir)s ;\n" % node_info
            
    content += "END COMPONENTS\n"

    while len(def_origin) > 0:
        content += def_origin.pop()


    with open(def_file, "w") as f:
        f.write(content)

    
    
    
    

