import copy, sys, time, json
from DG import *
import networkx as nx
import re

# sys.setrecursionlimit(10000)

class AST_analyzer(object):
    """Build a netlist dependency graph (``DG.Graph``) from a Verilog AST.

    The AST looks like a PyVerilog parse tree of a gate-level netlist (node
    types such as ``Decl``, ``InstanceList``, ``Partselect``, ``Pointer``) --
    TODO confirm. The pipeline (``AST2Graph``) is:

    1. declarations become graph nodes, recognised standard-cell instances
       become nodes connected to the nets on their ports;
    2. per-cell / per-net PPA data is attached from JSON reports;
    3. ``Wire`` nodes are collapsed so cells connect to each other directly.
    """

    # Default root of the per-design report directories read by update_ppa_info.
    PT_INFO_ROOT = "/home/coguest5/rtl_repr/data_collect/pt_info"

    # Ordered (pattern, category) table for standard-cell module names such as
    # "NAND2_X1". The first matching pattern wins, so the order matters.
    _CELL_TYPE_PATTERNS = (
        (re.compile(r'^(S)*DFF(R)*(S)*(\d)*_X(\d+)'), 'DFF'),
        (re.compile(r'^INV(\d)*_X(\d+)'), 'INV'),
        (re.compile(r'^BUF(\d)*_X(\d+)|^CLKBUF(\d)*_X(\d+)'), 'BUF'),
        (re.compile(r'^XOR(\d)*_X(\d+)'), 'XOR'),
        (re.compile(r'^AOI(\d)*_X(\d+)'), 'AOI'),
        (re.compile(r'^OAI(\d)*_X(\d+)'), 'OAI'),
        (re.compile(r'^OR(\d)*_X(\d+)'), 'OR'),
        (re.compile(r'^NAND(\d)*_X(\d+)'), 'NAND'),
        (re.compile(r'^AND(\d)*_X(\d+)'), 'AND'),
        (re.compile(r'^MUX(\d)*_X(\d+)'), 'MUX'),
        (re.compile(r'^NOR(\d)*_X(\d+)'), 'NOR'),
        (re.compile(r'^XNOR(\d)*_X(\d+)'), 'XNOR'),
        (re.compile(r'^HA(\d)*_X(\d+)'), 'HA'),
        (re.compile(r'^FA(\d)*_X(\d+)'), 'FA'),
        (re.compile(r'^DLL(\d)*_X(\d+)'), 'DLL'),
        (re.compile(r'^TBUF(\d)*_X(\d+)'), 'BUF'),
        (re.compile(r'^TINV(\d)*_X(\d+)'), 'INV'),
        (re.compile(r'^DLH(\d)*_X(\d+)'), 'DLH'),
    )
    # Anything that still looks like "<CELL><n>_X<n>" is a library cell this
    # analyzer does not know how to model.
    _UNSUPPORTED_CELL_RE = re.compile(r'^([A-Z])+(\d)*_X(\d+)')
    # Drive strength is the number after the last "_X" in the module name.
    _DRIVE_STRENGTH_RE = re.compile(r"(\S+)_X(\d+)")

    # Port-name -> direction tables used by convert_port_direction.
    _ADDER_INPUT_PORTS = frozenset({'A', 'B', 'CI'})
    _ADDER_OUTPUT_PORTS = frozenset({'S', 'CO'})
    _INPUT_PORTS = frozenset({
        'D', 'CK',
        'A', 'A1', 'A2', 'A3', 'A4',
        'B', 'B1', 'B2', 'B3', 'B4',
        'C1', 'C2',
        'RN', 'SI', 'SE', 'S', 'SN',
        'GN', 'EN', 'I', 'G',
    })
    _OUTPUT_PORTS = frozenset({'Q', 'QN', 'Z', 'ZN'})

    def __init__(self, ast):
        self.__ast = ast
        self.graph = Graph()
        self.oper_label = 0
        self.const_label = 0
        # Names of wire nodes (and, after eliminate_wires, their pointer /
        # part-select children); these are collapsed by eliminate_wires().
        self.wire_set = set()

        # Adjacency snapshots filled by eliminate_wire(). They are never
        # cleared, so entries accumulate across elimination passes.
        self.wire_dict = {}
        self.temp_dict = {}

        self.func_set = set()
        self.func_dict = {}

    def AST2Graph(self, ast, cmd, design_name):
        """Run the full AST -> graph pipeline on ``ast``.

        Args:
            ast: root AST node to traverse.
            cmd: first path component under the report root (see
                update_ppa_info).
            design_name: second path component under the report root.
        """
        self.traverse_AST(ast)
        self.graph.cal_node_width()
        self.update_ppa_info(cmd, design_name)
        self.eliminate_wires(self.graph)

    def traverse_AST(self, ast):
        """Visit every AST node in pre-order, registering declarations and
        cell instances.

        An explicit stack is used instead of recursion so that deep trees do
        not hit Python's recursion limit; the visiting order is the same as a
        recursive pre-order walk.
        """
        stack = [ast]
        while stack:
            node = stack.pop()
            node_type = node.get_type()
            self.add_decl_node(node, node_type)
            self.add_instance(node, node_type)
            # Push children reversed so they are popped in source order.
            stack.extend(reversed(list(node.children())))

    def add_decl_node(self, ast, node_type):
        """Add one graph node per variable declared by a ``Decl`` AST node.

        Each declared child becomes a node named after it, typed by its AST
        type (e.g. ``'Wire'``), with the bit width from get_width. Wire names
        are also recorded in ``self.wire_set``. Non-``Decl`` nodes are ignored.
        """
        if node_type != 'Decl':
            return
        children = ast.children()
        if not children:
            # A declaration with nothing declared is unexpected.
            print(len(children))
            print(ast)
            print(children)
            assert False, 'Decl node without declared children'
        for child in children:
            child_type = child.get_type()
            name = child.name
            width = self.get_width(child)
            self.graph.add_decl_node(name, child_type, width, None, child_type)
            if child_type == 'Wire':
                self.wire_set.add(name)

    def add_instance(self, ast, node_type):
        """Add a node for a cell instance and connect it to its port nets.

        For an input port the edge is ``instance -> net``; for an output port
        it is ``net -> instance``. Edges therefore run from a node toward the
        nodes that drive it. Instances whose module is not a recognised
        library cell (convert_inst_type returns ``''``) get a node but no
        edges. Unconnected ports (empty ``argname``) are skipped.
        """
        if node_type != 'InstanceList':
            return
        assert len(ast.instances) == 1
        inst = ast.instances[0]
        inst_name = inst.name
        inst_module = inst.module
        inst_tpe, strength = self.convert_inst_type(inst_module)
        self.graph.add_decl_node(inst_name, inst_module, 1, None, inst_tpe, strength)
        if not inst_tpe:
            return
        for port_arg in inst.portlist:
            if not port_arg.argname:
                continue
            port_node = self.add_new_node(port_arg.argname)
            assert port_node in self.graph.node_dict
            direc = self.convert_port_direction(port_arg.portname, inst_tpe)
            if direc == 'i':
                self.graph.add_edge(inst_name, port_node)
            elif direc == 'o':
                self.graph.add_edge(port_node, inst_name)

    def convert_inst_type(self, inst_tpe):
        """Classify a standard-cell module name.

        Args:
            inst_tpe: module name of the instance, e.g. ``'NAND2_X1'``.

        Returns:
            ``(category, strength)``. ``category`` is one of the labels in
            ``_CELL_TYPE_PATTERNS`` (e.g. ``'DFF'``, ``'NAND'``), or ``''``
            for a module that does not look like a library cell (the name is
            printed and the instance is skipped). ``strength`` is the integer
            after the last ``_X`` in the name, or 0 if there is none.

        Raises:
            AssertionError: for names shaped like a library cell
                (``<NAME><n>_X<n>``) that no known pattern covers.
        """
        ret_tpe = None
        for pattern, category in self._CELL_TYPE_PATTERNS:
            if pattern.search(inst_tpe):
                ret_tpe = category
                break

        if ret_tpe is None:
            if self._UNSUPPORTED_CELL_RE.search(inst_tpe):
                print(inst_tpe)
                assert False, f'Unsupported standard cell: {inst_tpe}'
            # Not a library cell (e.g. a user sub-module): report it and let
            # the caller skip it. No interactive pause here.
            print(inst_tpe)
            ret_tpe = ''

        strength = 0
        strength_re = self._DRIVE_STRENGTH_RE.findall(inst_tpe)
        if strength_re:
            strength = int(strength_re[0][-1])

        return ret_tpe, strength

    def convert_port_direction(self, port_name, inst_tpe=None):
        """Return ``'i'`` for an input pin or ``'o'`` for an output pin.

        Half/full adders (``'HA'``/``'FA'``) use their own table because
        their ``S`` pin is an output, whereas for other cells ``S`` is an
        input (e.g. a MUX select).

        Raises:
            AssertionError: for a pin name not in the tables.
        """
        if inst_tpe in ['HA', 'FA']:
            if port_name in self._ADDER_INPUT_PORTS:
                return 'i'
            if port_name in self._ADDER_OUTPUT_PORTS:
                return 'o'
        else:
            if port_name in self._INPUT_PORTS:
                return 'i'
            if port_name in self._OUTPUT_PORTS:
                return 'o'
        print(port_name)
        assert False

    def add_parent_edge(self):
        """Add an edge ``father -> node`` for every node that has a father
        (e.g. pointer / part-select nodes created by add_new_node)."""
        for name, node in self.graph.node_dict.items():
            if node.father:
                self.graph.add_edge(node.father, name)

    def cal_width(self, ast):
        """Return the bit count of an ``msb``/``lsb`` range (either order)."""
        msb = int(ast.msb.value)
        lsb = int(ast.lsb.value)
        return max(msb, lsb) - min(msb, lsb) + 1

    def get_width(self, ast):  # -> int
        """Return the declared size of a variable: bit width times the
        length of its first array dimension (each defaults to 1)."""
        width = ast.width
        dimens = ast.dimensions
        if width:
            width = self.cal_width(width)
        else:
            width = 1
        if dimens:
            length = self.cal_width(dimens.lengths[0])
        else:
            length = 1
        return width * length

    def get_node_width(self, ast):
        """Return the bit width of an expression operand.

        ``Concat`` yields ``None``. For a ``Partselect`` the part-select node
        is registered via add_new_node, but the width returned is that of the
        whole underlying variable.
        NOTE(review): confirm that returning the full variable width for a
        part-select is intended.
        """
        node_type = ast.get_type()
        parent_type = ast.get_parent_type()

        if node_type == 'Identifier':
            width = self.graph.node_dict[ast.name].width
        elif node_type == 'Pointer':
            width = 1
        elif node_type == 'Partselect':
            self.add_new_node(ast)
            width = self.graph.node_dict[ast.var.name].width
        elif node_type == 'IntConst':
            width = self.get_width_num(ast.value)
        elif node_type in ['Concat']:
            width = None
        elif parent_type == 'UnaryOperator':
            width = self.get_node_width(ast.right)
        else:
            print(node_type)
            assert False

        return width

    def add_new_node(self, ast):
        """Return the graph-node name for a port/operand expression, creating
        the node when needed.

        * ``Identifier``: its name (must already be declared).
        * ``Pointer`` ``x[k]``: node ``x_reg_k_`` (1 bit, father ``x``).
        * ``Partselect`` ``x[m:l]`` with constant bounds: node ``x.PSm_l``
          (father ``x``); with a non-constant bound, just ``x``.
        * ``IntConst``: the shared node ``'Const'``.
        """
        node_type = ast.get_type()
        node_dict = self.graph.node_dict
        if node_type == 'Identifier':
            node_name = ast.name
            assert node_name in node_dict
        elif node_type == 'Pointer':
            name = ast.var.name
            ptr = ast.ptr.value
            node_name = name + '_reg_' + ptr + '_'
            if node_name not in node_dict:
                self.graph.add_decl_node(node_name, 'Pointer', 1, name)
        elif node_type == 'Partselect':
            name = ast.var.name
            # Both bounds must be constants to name a distinct sub-node;
            # previously only msb was checked (twice).
            if ast.msb.get_type() != 'IntConst' or ast.lsb.get_type() != 'IntConst':
                node_name = name
            else:
                msb = ast.msb.value
                lsb = ast.lsb.value
                width = self.cal_width(ast)
                node_name = name + '.PS' + msb + '_' + lsb
                if node_name not in node_dict:
                    self.graph.add_decl_node(node_name, 'Partselect', width, name)
        elif node_type == 'IntConst':
            node_name = 'Const'
            self.graph.add_decl_node(node_name, 'Const', 1, None, 'Const')
        else:
            print(node_type)
            assert False
        return node_name

    def get_width_num(self, num):
        """Return the bit width of a Verilog integer-literal string.

        ``'0'``/``'1'`` are 1 bit. A sized based literal such as ``8'hFF``
        uses its size prefix. An unsized based literal such as ``'b1`` is
        32 bits, as Verilog specifies. Any other text falls back to
        ``len(num)``.
        """
        if num in ['0', '1']:
            return 1
        if '\'' in num:
            sized = re.findall(r"(\d+)'(\w+)", num)
            if sized:
                return int(sized[0][0])
            return 32
        # NOTE(review): a former "is string" regex matched the empty string,
        # so every remaining literal already took this path. len() is only a
        # heuristic and is wrong for multi-digit decimals (e.g. '10' -> 2).
        return len(num)

    def eliminate_wires(self, g: "Graph"):
        """Collapse wire nodes out of ``g`` and make the result ``self.graph``.

        Children whose ``father`` is a wire are treated as wires too.
        eliminate_wire is applied repeatedly until no wire remains or a pass
        removes none. Progress is printed. Finally ``self.graph`` is replaced
        by the rewritten graph, loaded with a copy of the old ``node_dict``.
        """
        print('----- Eliminating Wires in Graph -----')
        for name, node in self.graph.node_dict.items():
            if node.father in self.wire_set:
                self.wire_set.add(name)

        remaining = g.get_all_nodes2() & self.wire_set
        while remaining:
            pre_len = len(remaining)
            g = self.eliminate_wire(g)
            remaining = g.get_all_nodes2() & self.wire_set
            if len(remaining) == pre_len:
                break  # no progress; stop instead of looping forever

        if remaining:
            print('Warning: uneliminated wire: ', len(remaining))
            # NOTE(review): neighbours are looked up and nodes removed on
            # self.graph (the pre-elimination graph), not on the rewritten g.
            # Confirm whether g was intended here.
            for n in remaining.copy():
                neighbor = self.graph.get_neighbors(n)
                if len(neighbor) == 0:
                    self.graph.remove_node(n)
                    remaining.remove(n)

            print('Final uneliminated wire: ', len(remaining))
        else:
            print('Finish!\n')
        node_dict = self.graph.node_dict.copy()
        self.graph = g
        self.graph.load_node_dict(node_dict)

    def eliminate_wire(self, g: "Graph"):
        """Run one wire-elimination pass and return a new Graph.

        For every non-wire node, an edge to a wire ``w`` is replaced by edges
        to the neighbours recorded for ``w`` in ``self.wire_dict``, skipping
        falsy entries. Edges to non-wires are copied unchanged. Wire nodes'
        own edges are not copied. ``wire_dict``/``temp_dict`` persist across
        passes, so stale entries from earlier passes are also replayed.
        """
        for node in g.get_all_nodes():
            node_list = g.get_neighbors(node)
            if node in self.wire_set:
                self.wire_dict[node] = node_list
            else:
                self.temp_dict[node] = node_list
        g_new = Graph()
        for node, node_list in self.temp_dict.items():
            for n in node_list:
                if n in self.wire_dict:
                    for w in self.wire_dict[n]:
                        if w:
                            g_new.add_edge(node, w)
                else:
                    g_new.add_edge(node, n)
        return g_new

    def graph2dff_lst(self, design_name):
        """Return ``(adjacency, node_dict, dff_names)``, where ``dff_names``
        are the nodes whose ``tpe`` is ``'DFF'``. ``design_name`` is unused."""
        node_dict = self.graph.node_dict
        dff_set = {name for name, node in node_dict.items() if node.tpe == 'DFF'}
        return self.graph.graph, node_dict, dff_set

    @staticmethod
    def _load_json(path):
        """Read and decode one JSON report file."""
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def update_ppa_info(self, cmd, design_name, pt_info_root=None):
        """Attach power/area/net data from JSON reports to graph nodes.

        Reads ``cell.json``, ``net.json`` and ``net_delay.json`` from
        ``<pt_info_root>/<cmd>/<design_name>``; ``pt_info_root`` defaults to
        ``PT_INFO_ROOT``.

        * A node named in ``cell.json`` gets ``inter_pwr``, ``swith_pwr``,
          ``leak_pwr``, ``pwr`` (all scaled by 1e6) and ``area``.
        * For a node named in ``net.json`` / ``net_delay.json`` that is in
          the graph, every successor node gets ``load`` (x100), ``tr``,
          ``prob``, ``cap`` and ``res`` (x100).

        All values are rounded to 3 decimals. Nodes are mutated in place.
        """
        if pt_info_root is None:
            pt_info_root = self.PT_INFO_ROOT
        pt_info_dir = f"{pt_info_root}/{cmd}/{design_name}"
        cell_dict_all = self._load_json(f"{pt_info_dir}/cell.json")
        net_dict_all = self._load_json(f"{pt_info_dir}/net.json")
        net_delay_dict_all = self._load_json(f"{pt_info_dir}/net_delay.json")

        g_nx = nx.DiGraph(self.graph.graph)
        node_dict = self.graph.node_dict
        scale = 1000000

        for name in list(node_dict):
            if name in cell_dict_all:
                cell_dict = cell_dict_all[name]
                node = node_dict[name]
                node.inter_pwr = round(cell_dict['inter_pwr'] * scale, 3)
                # Attribute name 'swith_pwr' is kept as-is for existing readers.
                node.swith_pwr = round(cell_dict['switch_pwr'] * scale, 3)
                node.leak_pwr = round(cell_dict['leak_pwr'] * scale, 3)
                node.pwr = round(cell_dict['cell_pwr'] * scale, 3)
                node.area = round(cell_dict['cell_area'], 3)

            if name not in g_nx:
                continue

            # Successor nodes are updated in place. Previously the successor
            # object was also written back under the net's own key, clobbering
            # node_dict[name].
            if name in net_dict_all:
                net_dict = net_dict_all[name]
                for n in g_nx.successors(name):
                    node = node_dict[n]
                    node.load = round(net_dict['net_load'] * 100, 3)
                    node.tr = round(net_dict['net_tr'], 3)
                    node.prob = round(net_dict['net_prob'], 3)

            if name in net_delay_dict_all:
                net_delay_dict = net_delay_dict_all[name]
                for n in g_nx.successors(name):
                    node = node_dict[n]
                    node.cap = round(net_delay_dict['cap'], 3)
                    node.res = round(net_delay_dict['res'] * 100, 3)