import torch
import numpy as np
import pickle, json, time, re, sys
from DG import Graph, Node
import networkx as nx
from multiprocessing import Pool
import dgl
from dgl import from_networkx
import dgl


# One-hot slot layout for node types: the position of a group in this tuple
# is its index in the one-hot vector. A type of None (nothing recorded) is
# bucketed together with primary inputs.
_NODE_TYPE_GROUPS = (
    ('DFF',),
    ('Input', 'Inout', None),
    ('Output',),
    ('Wire',),
    ('Const',),
    ('BUF',),
    ('INV',),
    ('AND',),
    ('OR',),
    ('XOR',),
    ('NOR',),
    ('XNOR',),
    ('NAND',),
    ('AOI',),
    ('OAI',),
    ('MUX',),
    ('DLL',),
    ('HA',),
    ('FA',),
)
# Flattened lookup: type name -> one-hot slot index.
_NODE_TYPE_INDEX = {
    tpe: idx for idx, group in enumerate(_NODE_TYPE_GROUPS) for tpe in group
}
_NUM_NODE_TYPES = len(_NODE_TYPE_GROUPS)


def node_feat_extra_net(
    node_name,
    node_class: "Node",
    g_nx: "nx.DiGraph",
    node_dict,
):
    """Build the feature vector and the label pair for one graph node.

    Feature layout (length 8 + 19):
        [fanin, fanout, fanin + fanout, strength, pwr, area, cap, res]
        followed by a 19-way one-hot encoding of the node type
        (see ``_NODE_TYPE_GROUPS`` for the slot order).

    Args:
        node_name: Key of the node in ``g_nx`` and ``node_dict``.
        node_class: Object whose ``tpe`` attribute gives the node type.
        g_nx: Graph providing ``predecessors`` / ``successors``.
        node_dict: Mapping from node name to an object exposing
            ``strength``, ``pwr``, ``area``, ``cap``, ``res``, ``prob``
            and ``tr`` attributes.

    Returns:
        ``(feat_vec, [prob, tr])`` where ``feat_vec`` is a 1-D numpy array
        and ``prob`` / ``tr`` are read from ``node_dict[node_name]``
        (presumably regression targets -- confirm against the caller).

    Raises:
        ValueError: if ``node_class.tpe`` is not a recognised node type.
    """
    # 1-2. Structural features: count of (distinct) predecessors/successors.
    fanin = sum(1 for _ in g_nx.predecessors(node_name))
    fanout = sum(1 for _ in g_nx.successors(node_name))

    # 3. One-hot node type, taken from the `node_class` argument.
    node_tpe_ori = node_class.tpe
    try:
        type_idx = _NODE_TYPE_INDEX[node_tpe_ori]
    except KeyError:
        # Explicit raise instead of `assert False`: asserts vanish under -O.
        raise ValueError(f"unknown node type: {node_tpe_ori!r}") from None
    node_type = [0] * _NUM_NODE_TYPES
    node_type[type_idx] = 1

    # 4. Electrical attributes come from node_dict, not from `node_class`.
    # NOTE(review): type is read from the `node_class` argument while the
    # attributes below are read from node_dict[node_name]; presumably these
    # are the same Node object -- confirm.
    attrs = node_dict[node_name]
    # A falsy strength (None or 0) falls back to 1.0.
    node_strength = attrs.strength or float(1)

    feat_vec = np.array(
        [
            fanin,
            fanout,
            fanin + fanout,
            node_strength,
            attrs.pwr,
            attrs.area,
            attrs.cap,
            attrs.res,
        ]
        + node_type
    )

    return feat_vec, [attrs.prob, attrs.tr]



