import os, gzip, re, bisect
import dgl
import dgl.function as fn
from dgl.transforms import add_self_loop
import numpy as np
import torch
from scipy import ndimage
import cv2
import tqdm
from torch.utils.data import DataLoader, Dataset
import pickle
import binascii

def copy_src(src, out):
    """Build a message function that copies a source-node feature.

    Thin wrapper that forwards both arguments to ``fn.copy_u``.

    Args:
        src: Name of the source-node feature field to read.
        out: Name of the message field to write.

    Returns:
        The object returned by ``fn.copy_u(src, out)``.
    """
    message_func = fn.copy_u(src, out)
    return message_func

# Install the wrapper as ``fn.copy_src`` so that call sites written against
# that name (e.g. ``fo_average`` below) resolve to ``fn.copy_u``.
fn.copy_src = copy_src
def get_unit(path):
    """Read the integer of the ``DATABASE`` statement in a LEF ``UNITS`` section.

    The file is scanned line by line.  Once a line starting with ``UNITS`` has
    been seen, the third whitespace-separated token of a line starting with
    ``DATABASE`` (e.g. ``DATABASE MICRONS 2000 ;``) is taken as the value.
    Scanning stops at ``END UNITS`` so that unrelated ``DATABASE`` lines later
    in the file cannot overwrite it.

    Args:
        path: Path of the plain-text LEF file.

    Returns:
        The value as an ``int``.

    Raises:
        ValueError: If no ``DATABASE`` statement is found inside a ``UNITS``
            section, or if its third token is not an integer.
    """
    unit = None
    in_units = False
    with open(path, 'r') as read_file:
        for line in read_file:
            stripped = line.strip()
            if stripped.startswith('UNITS'):
                in_units = True
            if not in_units:
                continue
            if stripped.startswith('DATABASE'):
                # Tolerate a terminator glued to the number ("2000;").
                unit = stripped.split()[2].rstrip(';')
            elif stripped.startswith('END UNITS'):
                # The line is compared after stripping; comparing the raw line
                # (which still carries its newline) would never match.
                break
    if unit is None:
        raise ValueError('no DATABASE statement found in a UNITS section of %s' % path)
    return int(unit)

def fo_average(g):
    """Compute a one-hop, inverse-out-degree-weighted average of ``g.ndata['feat']``.

    Every node sends ``feat / out_degree`` and ``1 / out_degree`` along its
    edges; each node sums what it receives and the two sums are divided.

    Side effects: the node fields ``'addnlfeat'``, ``'inter'``, ``'wts'`` and
    ``'wtmsg'`` are written to ``g.ndata`` and left there for the caller.

    Args:
        g: Graph with a ``'feat'`` node field.  Nodes with out-degree 0 lead to
            a division by zero.

    Returns:
        Tensor with one averaged feature row per node.
    """
    out_deg = g.out_degrees(g.nodes()).type(torch.float32)
    feat = g.ndata['feat']

    g.ndata['addnlfeat'] = feat / out_deg.view(-1, 1)
    g.ndata['inter'] = torch.zeros_like(feat)
    g.ndata['wts'] = torch.ones(g.number_of_nodes()) / out_deg
    g.ndata['wtmsg'] = torch.zeros_like(g.ndata['wts'])

    # Same propagate-and-sum step for the weighted features and for the weights.
    for field, message in (('addnlfeat', 'inter'), ('wts', 'wtmsg')):
        g.update_all(message_func=fn.copy_src(src=field, out=message),
                     reduce_func=fn.sum(msg=message, out=field))

    return g.ndata['addnlfeat'] / (g.ndata['wts'].view(-1, 1))

def get_partition_list_random(g, p_size):
    """Randomly split the node ids of ``g`` into ``p_size`` groups.

    The ids returned by ``g.nodes()`` are shuffled once with
    ``np.random.permutation`` and then dealt out round-robin, so group sizes
    differ by at most one.

    Args:
        g: Object exposing ``nodes()``.
        p_size: Number of groups to produce.

    Returns:
        A list of ``p_size`` NumPy arrays of node ids.
    """
    shuffled = np.random.permutation(g.nodes())
    partitions = []
    for offset in range(p_size):
        partitions.append(shuffled[offset::p_size])
    return partitions

def node_pairs_among(nodes, max_cap=-1):
    """List directed pairs ``(u, v)`` with ``u != v`` among ``nodes``.

    When ``max_cap`` is ``-1`` or ``nodes`` has at most ``max_cap`` entries,
    every ordered pair is produced.  Otherwise each node is paired with up to
    ``max_cap - 1`` partners drawn from a fresh random permutation of
    ``nodes``.

    Args:
        nodes: Sequence of node ids.
        max_cap: Group size above which sampling kicks in; ``-1`` disables it.

    Returns:
        Tuple ``(us, vs)`` of equally long lists holding pair sources and
        pair targets.
    """
    sources, targets = [], []

    if max_cap == -1 or len(nodes) <= max_cap:
        # Small group: full pairwise connection.
        for src in nodes:
            partners = [dst for dst in nodes if not src == dst]
            sources.extend([src] * len(partners))
            targets.extend(partners)
        return sources, targets

    # Large group: bounded random sample of partners per node.
    for src in nodes:
        remaining = max_cap - 1
        for candidate in np.random.permutation(nodes):
            if remaining == 0:
                break
            if src == candidate:
                continue
            sources.append(src)
            targets.append(candidate)
            remaining -= 1
    return sources, targets

def is_gzip_file(file_path):
    """Tell whether a file starts with the two bytes ``1f 8b``.

    Args:
        file_path: Path of the file to inspect.

    Returns:
        ``True`` if the first two bytes are ``0x1f 0x8b``, else ``False``
        (including for files shorter than two bytes).
    """
    with open(file_path, 'rb') as handle:
        magic = handle.read(2)
    return magic == b'\x1f\x8b'


def instance_direction_rect(line):  # used when we only need bounding box (rect) of the cell.
    """Return the coefficients that map a cell's (w, h) to its placed (sx, sy).

    The result ``(a, b, c, d)`` is applied by the caller as
    ``sx = w * a + h * b`` and ``sy = w * c + h * d``.

    Args:
        line: Orientation text; it is only tested for the letters it contains.

    Returns:
        ``(1, 0, 0, 1)`` if ``line`` contains ``N`` or ``S`` (checked first),
        ``(0, 1, 1, 0)`` if it contains ``W`` or ``E``.

    Raises:
        ValueError: If none of the four letters occurs in ``line``.
    """
    if any(letter in line for letter in ('N', 'S')):
        return (1, 0, 0, 1)
    if any(letter in line for letter in ('W', 'E')):
        return (0, 1, 1, 0)
    raise ValueError('read_macro_direction_wrong')

# Orientation -> 12 coefficients consumed by ``read_place_def``:
#   [0:4]   applied to the "opposite" pin-rect corner,
#   [4:8]   applied to the "own" pin-rect corner,
#   [8:12]  applied to the placed cell size (sx, sy) to obtain the offset.
_BOTTOM_LEFT_COEFFICIENTS = {
    'N': (0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0),
    'S': (-1, 0, 0, -1, 0, 0, 0, 0, 1, 0, 0, 1),
    'W': (0, -1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0),
    'E': (0, 0, -1, 0, 0, 1, 0, 0, 0, 0, 1, 0),
    'FN': (-1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0),
    'FS': (0, 0, 0, -1, 1, 0, 0, 0, 0, 0, 0, 1),
    'FW': (0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0),
    'FE': (0, -1, -1, 0, 0, 0, 0, 0, 0, 1, 1, 0),
}


def instance_direction_bottom_left(direction):  # used when we need to get the bottom left corner of the cell.
    """Look up the 12 transform coefficients for an exact orientation code.

    Args:
        direction: One of ``N``, ``S``, ``W``, ``E``, ``FN``, ``FS``, ``FW``,
            ``FE``.

    Returns:
        A 12-tuple of integers in ``{-1, 0, 1}``.

    Raises:
        ValueError: If ``direction`` is not one of the eight codes.
    """
    try:
        return _BOTTOM_LEFT_COEFFICIENTS[direction]
    except (KeyError, TypeError):
        raise ValueError('read_macro_direction_wrong') from None
def resize(input):
    """Resample a 2-D array to 256 x 256 with ``ndimage.zoom`` (spline order 3).

    Args:
        input: Array whose first two ``shape`` entries are the current size.

    Returns:
        The zoomed array.
    """
    rows, cols = input.shape[0], input.shape[1]
    zoom_factors = (256 / rows, 256 / cols)
    return ndimage.zoom(input, zoom_factors, order=3)

def std(input):
    """Min-max normalise an array: ``(x - min) / (max - min)``.

    Args:
        input: Array-like exposing ``max()`` and ``min()`` and supporting
            element-wise arithmetic.

    Returns:
        ``input`` itself (unchanged) when its maximum is 0; otherwise the
        normalised array.  A constant non-zero input yields all zeros instead
        of the NaNs a 0/0 division would produce.
    """
    highest = input.max()
    if highest == 0:
        return input
    lowest = input.min()
    span = highest - lowest
    if span == 0:
        # Constant, non-zero input: (x - min) is already all zeros, so divide
        # by 1 rather than by the zero span.
        span = 1
    return (input - lowest) / span

def resize_cv2(input):
    """Resize an image to 256 x 256 with ``cv2.resize`` using ``cv2.INTER_AREA``.

    Args:
        input: Image array accepted by ``cv2.resize``.

    Returns:
        The resized image.
    """
    target_size = (256, 256)
    return cv2.resize(input, target_size, interpolation=cv2.INTER_AREA)

class NetlistgnnDataset(Dataset):
    def __init__(self, data_root, namelist, args):
        """Load, or build and cache, the graphs of every design in ``namelist``.

        For each design name a pickle is looked up under
        ``<data_root>/<save_graph_root><0|1>/``.  If it is missing, the LEF and
        the placement / routing DEF files are parsed and ``generate_graph``
        writes the pickle.  In both cases the pickle is then loaded and
        appended to ``self.graphs``.

        Args:
            data_root: Root directory; every other path is joined onto it.
            namelist: File (relative to ``data_root``) with one design name
                per line.
            args: Options object.  ``save_graph_root``, ``pretrain``,
                ``lef_path``, ``place_def_root`` and ``route_def_root`` are
                read here; the object is also passed to
                ``get_circuit_features``.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.data_root = data_root
        # Separate cache directories: suffix "0" when pretraining, "1" otherwise.
        self.save_graph_root = args.save_graph_root + ("0" if args.pretrain else "1")
        os.makedirs(os.path.join(self.data_root, self.save_graph_root), exist_ok=True)

        # Context manager so the handle is released even if reading fails.
        with open(os.path.join(data_root, namelist), mode='r') as list_file:
            self.name_list = [line[:-1] if line.endswith('\n') else line for line in list_file]

        # Load the cached graph of each design, generating it first if absent.
        self.graphs = []
        self.get_circuit_features(args)
        for i, name in enumerate(self.name_list):
            self.graph_pickle = os.path.join(self.data_root, self.save_graph_root,
                                             name.split('.def')[0] + '.pickle')
            if not os.path.exists(self.graph_pickle):
                self.use_tqdm = 1
                self.graph_scale = 10000
                self.die_area = None  # Total area of die (chip). Currently not used.
                self.gcell_size = [-1, -1]  # Number of Gcell grids.
                self.gcell_coordinate_x = []  # Coordinate of Gcell grids in DEF.
                self.gcell_coordinate_y = []
                # contains information of instances {inst_name:[std_cell_name, [inst_location_w/h], inst_direction]}
                self.route_instance_dict = {}
                # contains instance names and pin names in each nets {net_name:[pin_names]}
                self.route_net_dict = {}
                # contains information of IO pins {pin_name{'layer':layer, 'rect':rect, 'location':location, 'direction':direction}}
                self.route_pin_dict = {}
                self.place_instance_dict = {}  # same as route
                self.place_net_dict = {}
                self.place_pin_dict = {}
                self.lef_dict = {}
                # Per-node (component) features.
                self.sizedata_x = []
                self.sizedata_y = []
                self.pdata = []
                self.xdata = []
                self.ydata = []
                self.xsdata = []
                self.ysdata = []
                # Per-net features.
                self.span_v = []
                self.span_h = []
                self.span_pv = []
                self.span_ph = []
                self.n_pin = []
                # Other features.
                self.direction = []

                self.lef_path = os.path.join(self.data_root, args.lef_path)
                self.place_def_path = os.path.join(self.data_root, args.place_def_root, name + '.def.gz')
                self.route_def_path = os.path.join(self.data_root, args.route_def_root, name + '.def.gz')
                self.read_lef()
                self.read_route_def()
                self.read_place_def()
                self.generate_graph(i)

            # NOTE: pickle.load can execute arbitrary code; the cache directory
            # must only contain files written by this class.
            with open(self.graph_pickle, 'rb') as fp:
                g = pickle.load(fp)
            self.graphs.append(g)




    def read_lef(self):
        """Parse macro sizes and pin bounding boxes from ``self.lef_path``.

        Fills ``self.lef_dict`` as::

            {cell_name: {'size': [unit * w, unit * h],
                         'pin': {pin_name: [left, lower, right, upper, flag]}}}

        where the pin rectangle is the bounding box of all ``RECT`` lines of
        that pin multiplied by ``unit``, and ``flag`` is 0 for
        ``DIRECTION OUTPUT`` and 1 otherwise.  Also stores the value returned
        by ``get_unit`` in ``self.unit``.

        Raises:
            ValueError: From ``min``/``max`` if a pin ends without any
                ``RECT`` line.
        """
        unit = get_unit(self.lef_path)
        self.unit = unit

        cell_name = ''
        pin_name = ''
        rect_list_left = []
        rect_list_lower = []
        rect_list_right = []
        rect_list_upper = []
        in_macro = False

        with open(self.lef_path, 'r') as read_file:
            for line in read_file:
                stripped = line.lstrip()
                if stripped.startswith('MACRO'):
                    in_macro = True
                    cell_name = line.split()[1]
                    self.lef_dict[cell_name] = {'pin': {}}

                if not in_macro:
                    continue

                if stripped.startswith('SIZE'):
                    numbers = re.findall(r'-?\d+\.?\d*e?-?\d*?', line)
                    # size [unit*w, unit*h]
                    self.lef_dict[cell_name]['size'] = [unit * float(numbers[0]), unit * float(numbers[1])]

                elif stripped.startswith('PIN'):
                    pin_name = line.split()[1]
                    # Drop RECTs collected since the previous pin ended (for
                    # example a macro's OBS geometry) so they cannot leak into
                    # this pin's bounding box.
                    rect_list_left = []
                    rect_list_lower = []
                    rect_list_right = []
                    rect_list_upper = []
                    # Same flag a non-OUTPUT DIRECTION line would give; keeps
                    # pins that carry no DIRECTION line from raising KeyError.
                    self.lef_dict[cell_name]['pin'][pin_name] = 1

                elif stripped.startswith('RECT'):
                    fields = line.split()
                    rect_list_left.append(float(fields[1]))
                    rect_list_lower.append(float(fields[2]))
                    rect_list_right.append(float(fields[3]))
                    rect_list_upper.append(float(fields[4]))

                elif stripped.startswith('DIRECTION'):
                    fields = line.split()
                    flag = 0 if fields[1] == 'OUTPUT' else 1
                    self.lef_dict[cell_name]['pin'][pin_name] = flag

                elif line.split() == ['END', pin_name]:
                    # Token comparison tolerates trailing blanks, CRLF line
                    # ends and a missing final newline.
                    flag = self.lef_dict[cell_name]['pin'][pin_name]
                    self.lef_dict[cell_name]['pin'][pin_name] = [
                        min(rect_list_left) * unit,
                        min(rect_list_lower) * unit,
                        max(rect_list_right) * unit,
                        max(rect_list_upper) * unit,
                        flag,
                    ]  # pin_rect
                    rect_list_left = []
                    rect_list_lower = []
                    rect_list_right = []
                    rect_list_upper = []
    def read_route_def(self):
        """Derive the g-cell grid from the ``GCELLGRID`` lines of the routing DEF.

        Only lines of the form ``GCELLGRID <X|Y> <start> DO <count> STEP <step> ;``
        (exactly 8 tokens) are used.  Reading stops at the first line after the
        ``GCELLGRID`` run.  Results:

        * ``self.gcell_size``: ``[sum(X counts) - 1, sum(Y counts) - 1]``.
        * ``self.gcell_coordinate_x`` / ``_y``: NumPy arrays of grid
          coordinates.  The segment with the smallest start contributes
          ``count - 1`` points at ``start + k * step``; every following
          segment contributes ``count`` points, each ``step`` past the
          previous point.

        If the file contains no ``GCELLGRID`` line the attributes keep their
        reset values (``[-1, -1]`` and empty lists).

        Raises:
            ValueError: If fewer than three X segments were read.
            IndexError: If no Y segment was read.
        """

        def expand(segments):
            # Consumes ``segments`` from the end, mirroring a stack of
            # [start, count, step] entries ordered largest start first.
            start, count, step = segments.pop()
            coords = [start + (k + 1) * step for k in range(count - 1)]
            cursor = coords[-1] if coords else start
            while segments:
                _, count, step = segments.pop()
                for _ in range(count):
                    cursor += step
                    coords.append(cursor)
            return np.array(coords)

        self.gcell_size = [-1, -1]
        self.gcell_coordinate_x = []
        self.gcell_coordinate_y = []
        segments_x = []
        segments_y = []
        seen_gcell = False

        opener = gzip.open if is_gzip_file(self.route_def_path) else open
        # ``with`` closes the file on every exit path, including the early stop.
        with opener(self.route_def_path, "rt") as read_file:
            for line in read_file:
                line = line.lstrip()
                if not line.startswith("GCELLGRID"):
                    if seen_gcell:
                        break  # first line after the GCELLGRID run
                    continue
                seen_gcell = True
                tokens = line.split()
                if len(tokens) != 8:
                    continue
                # at <start> do <count> step <step>
                segment = [int(tokens[2]), int(tokens[4]), int(tokens[6])]
                if 'Y' in line:
                    self.gcell_size[1] += segment[1]
                    segments_y.append(segment)
                elif 'X' in line:
                    self.gcell_size[0] += segment[1]
                    segments_x.append(segment)

        if not seen_gcell:
            return
        if len(segments_x) <= 2:
            raise ValueError('expected at least three GCELLGRID X statements in %s'
                             % self.route_def_path)
        # Make the smallest start the last element so that it is popped first;
        # the decision is taken on X and applied to both axes.
        if int(segments_x[0][0]) < int(segments_x[-1][0]):
            segments_x.reverse()
            segments_y.reverse()

        self.gcell_coordinate_y = expand(segments_y)
        self.gcell_coordinate_x = expand(segments_x)

    def read_place_def(self):
        """Parse components, net pins and IO-pin rectangles from the placement DEF.

        Populates:

        * ``self.die_area``: third and fourth integer of the ``DIEAREA`` line.
        * Per component (``FIXED`` / ``PLACED`` lines inside ``COMPONENTS``):
          ``xdata`` / ``ydata`` (DEF coordinates), ``xsdata`` / ``ysdata``
          (``bisect_left`` position in the g-cell coordinates),
          ``sizedata_x`` / ``sizedata_y`` (orientation-aware size),
          ``direction``, ``componentname2indexmap`` and
          ``place_instance_dict`` (component name -> cell name).
        * ``self.pdata``: number of net pins seen per component.
        * ``self.place_net_dict``: ``{net_index: [[component_index, left,
          lower, right, upper, flag], ...]}`` with the pin rectangle divided by
          ``self.unit`` and ``flag`` taken from the LEF pin entry.
        * ``self.place_pin_dict``: ``{pin_name: [last four integers of the
          pin's '+ LAYER' line]}``.

        Net terminals whose component is unknown (e.g. ``( PIN name )``) are
        skipped.  ``+ HALO`` statements are ignored.

        Raises:
            KeyError: If a component's cell or a net pin is missing from
                ``self.lef_dict``.
        """
        opener = gzip.open if is_gzip_file(self.place_def_path) else open
        self.componentname2indexmap = {}
        self.netname2indexmap = {}
        in_components = False
        in_nets = False
        in_pins = False
        net = ''
        netindex = 0
        index = 0

        # ``with`` closes the file even when parsing raises.
        with opener(self.place_def_path, "rt") as read_file:
            for line in read_file:
                line = line.lstrip()

                # Section tracking.
                if line.startswith("DIEAREA"):
                    die_coordinate = re.findall(r'\d+', line)
                    self.die_area = (int(die_coordinate[2]), int(die_coordinate[3]))
                elif line.startswith("COMPONENTS"):
                    in_components = True
                elif line.startswith("END COMPONENTS"):
                    in_components = False
                    component_num = len(self.sizedata_x)
                    self.pdata = [0] * component_num
                elif line.startswith("NETS"):
                    in_nets = True
                elif line.startswith("END NETS") or line.startswith("SPECIALNETS"):
                    in_nets = False
                elif line.startswith('PIN'):
                    in_pins = True
                elif line.startswith('END PINS'):
                    in_pins = False

                if in_components:
                    if "FIXED" in line or "PLACED" in line:
                        instance = line.split()
                        l = instance.index('(')
                        self.xdata.append(int(instance[l + 1]))
                        self.ydata.append(int(instance[l + 2]))
                        self.xsdata.append(bisect.bisect_left(self.gcell_coordinate_x, int(instance[l + 1])))
                        self.ysdata.append(bisect.bisect_left(self.gcell_coordinate_y, int(instance[l + 2])))
                        cell_size = self.lef_dict[instance[2]]['size']
                        direction = instance_direction_rect(instance[l + 4])
                        # NOTE(review): the original author asked whether this
                        # is only half of the size - confirm.
                        sx = cell_size[0] * direction[0] + cell_size[1] * direction[1]
                        sy = cell_size[0] * direction[2] + cell_size[1] * direction[3]
                        self.sizedata_x.append(sx)
                        self.sizedata_y.append(sy)
                        self.direction.append(instance[l + 4])
                        self.componentname2indexmap[instance[1]] = index
                        index += 1

                        self.place_instance_dict[instance[1]] = instance[2]

                if in_nets:
                    if line.startswith('-'):
                        # Nets are keyed by their running index, not by name.
                        net = netindex
                        netindex += 1
                        self.place_net_dict[net] = []

                    elif line.startswith('('):  # terminals "( component pin )"
                        tokens = line.split()
                        # enumerate keeps the position in step with the token
                        # even when a terminal is skipped; a hand-maintained
                        # counter fell behind on ``continue`` and silently
                        # dropped the remaining terminals of the line.
                        for pos, token in enumerate(tokens):
                            if token != '(':
                                continue
                            component_name = tokens[pos + 1]
                            pin_name = tokens[pos + 2]
                            if component_name not in self.componentname2indexmap:
                                continue
                            component_index = self.componentname2indexmap[component_name]
                            std_cell_x = self.sizedata_x[component_index]
                            std_cell_y = self.sizedata_y[component_index]
                            self.pdata[component_index] += 1
                            direction = instance_direction_bottom_left(self.direction[component_index])
                            # Offset contributed by the placed cell size.
                            pin_px = direction[8] * std_cell_x + direction[9] * std_cell_y
                            pin_py = direction[10] * std_cell_x + direction[11] * std_cell_y

                            # pin_location_on_instance left/lower/right/upper(/flag)
                            pin_location_on_instance = self.lef_dict[self.place_instance_dict[component_name]]['pin'][
                                pin_name]

                            pin_left = pin_px + pin_location_on_instance[0] * direction[4] + \
                                       pin_location_on_instance[1] * direction[5] + pin_location_on_instance[2] * \
                                       direction[0] + pin_location_on_instance[3] * direction[1]
                            pin_lower = pin_py + pin_location_on_instance[0] * direction[6] + \
                                        pin_location_on_instance[1] * direction[7] + pin_location_on_instance[2] * \
                                        direction[2] + pin_location_on_instance[3] * direction[3]
                            pin_right = pin_px + pin_location_on_instance[2] * direction[4] + \
                                        pin_location_on_instance[3] * direction[5] + pin_location_on_instance[0] * \
                                        direction[0] + pin_location_on_instance[1] * direction[1]
                            pin_upper = pin_py + pin_location_on_instance[2] * direction[6] + \
                                        pin_location_on_instance[3] * direction[7] + pin_location_on_instance[0] * \
                                        direction[2] + pin_location_on_instance[1] * direction[3]

                            self.place_net_dict[net].append(
                                [component_index, pin_left / self.unit, pin_lower / self.unit, pin_right / self.unit,
                                 pin_upper / self.unit, pin_location_on_instance[4]])

                if in_pins:  # primary IO pins
                    if line.startswith('-'):
                        pin = line.split()[1]
                    elif line.strip().startswith('+ LAYER'):
                        # Keep the sign: pin rectangles may use negative offsets.
                        pin_rect = re.findall(r'-?\d+', line)
                        self.place_pin_dict[pin] = [int(pin_rect[-4]), int(pin_rect[-3]), int(pin_rect[-2]),
                                                    int(pin_rect[-1])]

    def distance_among(self, a: int, b: int) -> float:
        """Return the Euclidean distance between the centres of components ``a`` and ``b``.

        A centre is ``(xdata + 0.5 * sizedata_x, ydata + 0.5 * sizedata_y)``.

        Args:
            a: Index of the first component.
            b: Index of the second component.

        Returns:
            The distance between the two centres.
        """
        dx = self.xdata[a] + self.sizedata_x[a] * 0.5 - self.xdata[b] - self.sizedata_x[b] * 0.5
        dy = self.ydata[a] + self.sizedata_y[a] * 0.5 - self.ydata[b] - self.sizedata_y[b] * 0.5
        return (dx ** 2 + dy ** 2) ** 0.5
    def generate_graph(self, ii):
        """Build, partition and cache the graphs of the design that was just parsed.

        Steps:

        1. Node features: placed size, pin count and the values of
           ``self.features[ii]`` at each component's rescaled grid position;
           labels come from ``self.labels[ii]`` in the same way.
        2. Homogeneous graph: components sharing a net are connected
           (``node_pairs_among`` with ``max_cap=8``), self loops are added and
           the ``fo_average`` output is concatenated to the features.
        3. Heterogeneous graph with ``node`` and ``net`` node types: ``near``
           edges join components that fall into the same window of four
           shifted grids (``max_cap=5``); ``pins`` / ``pinned`` edges join
           components and nets.
        4. Both graphs are cut along the same random node partitions; a net is
           kept in a partition only if all of its pins are inside it.
        5. The list of ``(homo_subgraph, hetero_subgraph)`` tuples is pickled
           to ``self.graph_pickle``.

        Args:
            ii: Index of the design along the first axis of ``self.features``
                and ``self.labels`` (presumably filled by
                ``get_circuit_features`` - confirm).

        Returns:
            The list of ``(homo_subgraph, hetero_subgraph)`` tuples.
        """
        self.sizedata_x = np.array(self.sizedata_x)
        self.sizedata_y = np.array(self.sizedata_y)
        self.pdata = np.array(self.pdata)
        # Rescale g-cell indices to a 0..256 range.
        # NOTE(review): an index equal to gcell_size maps to 256 - confirm that
        # this is a valid position in self.features / self.labels.
        self.xsdata = (np.array(self.xsdata) / self.gcell_size[0] * 256).astype(int)
        self.ysdata = (np.array(self.ysdata) / self.gcell_size[1] * 256).astype(int)
        self.xdata = np.array(self.xdata)
        self.ydata = np.array(self.ydata)
        self.circuit_feature = self.features[ii, :, self.xsdata, self.ysdata]
        self.circuit_label = self.labels[ii, :, self.xsdata, self.ysdata]
        # xdata/ydata are DEF coordinates, xsdata/ysdata are grid positions.
        node_hv = torch.tensor(
            np.vstack((self.sizedata_x, self.sizedata_y, self.pdata, self.circuit_feature)),
            dtype=torch.float32).t()
        node_pos = torch.tensor(np.vstack((self.xdata, self.ydata)), dtype=torch.float32).t()
        n_node = node_hv.shape[0]

        # ---- homogeneous graph: connect components that share a net ----
        us, vs = [], []
        for list_node_feats in self.place_net_dict.values():
            nodes = [node_feats[0] for node_feats in list_node_feats]
            us_, vs_ = node_pairs_among(nodes, max_cap=8)
            us.extend(us_)
            vs.extend(vs_)
        homo_graph = add_self_loop(dgl.graph((us, vs), num_nodes=n_node))
        homo_graph.ndata['pos'] = node_pos[:n_node, :]
        homo_graph.ndata['feat'] = node_hv[:n_node, :]
        extra = fo_average(homo_graph)
        # Remove the scratch fields fo_average leaves behind.
        homo_graph.ndata.pop('inter')
        homo_graph.ndata.pop('addnlfeat')
        homo_graph.ndata.pop('wts')
        homo_graph.ndata.pop('wtmsg')
        # Original features followed by their one-hop average.
        homo_graph.ndata['feat'] = torch.cat([homo_graph.ndata['feat'], extra], dim=1)
        homo_graph.ndata['label'] = torch.tensor(self.circuit_label, dtype=torch.float32).t()
        partition_list = get_partition_list_random(homo_graph, int(np.ceil(n_node / self.graph_scale)))
        list_homo_graph = [dgl.node_subgraph(homo_graph, partition) for partition in partition_list]
        print('\thomo_graph generated')

        # ---- heterogeneous graph, part 1: "near" edges and position code ----
        n_dim = homo_graph.ndata['feat'].shape[1]
        us4, vs4 = [], []
        off_temps = [5678, 7654, 8888, 10035]
        # np.float64 instead of np.float_: the alias was removed in NumPy 2.0.
        node_pos_code = np.zeros([n_node, n_dim], dtype=np.float64)
        win_x = (self.gcell_coordinate_x[1:] - self.gcell_coordinate_x[:-1]).mean()
        win_y = (self.gcell_coordinate_y[1:] - self.gcell_coordinate_y[:-1]).mean()
        even_dims = range(0, n_dim, 2)
        odd_dims = range(1, n_dim, 2)
        # Four window grids: unshifted and shifted by half a window in x, y, both.
        for off_idx, (x_offset, y_offset) in enumerate(
                [(0, 0), (win_x / 2, 0), (0, win_y / 2), (win_x / 2, win_y / 2)]):
            # The sinusoid denominators depend only on the grid, so they are
            # computed once per grid instead of once per node.
            base = off_temps[off_idx]
            sin_scale = np.array([base ** (di / n_dim) for di in even_dims], dtype=np.float64)
            cos_scale = np.array([base ** ((di - 1) / n_dim) for di in odd_dims], dtype=np.float64)

            box_node = {}  # "x-y" window key -> indices of components overlapping it
            placements = enumerate(zip(self.sizedata_x, self.sizedata_y, self.xdata, self.ydata))
            iter_sp = tqdm.tqdm(placements, total=n_node) if self.use_tqdm else placements
            for i, (sx, sy, px, py) in iter_sp:  # sx, sy: size; px, py: position
                if px == 0 and py == 0:
                    continue
                px += x_offset
                py += y_offset
                x_1, x_2 = int(px / win_x), int((px + sx) / win_x)
                y_1, y_2 = int(py / win_y), int((py + sy) / win_y)
                for x in range(x_1, x_2 + 1):
                    for y in range(y_1, y_2 + 1):
                        box_node.setdefault(f'{x}-{y}', []).append(i)
                # Sinusoidal code of the centre's fractional position inside
                # its window; accumulated over the four grids.
                pos_idx = 20 * (((px + sx * 0.5) / win_x) % 1.0) + 5 * (((py + sy * 0.5) / win_y) % 1.0)
                node_pos_code[i, 0::2] += np.sin(pos_idx / sin_scale)
                node_pos_code[i, 1::2] += np.cos(pos_idx / cos_scale)

            # Connect components sharing a window; each component gets at most
            # max_cap - 1 partners in crowded windows.
            for nodes in box_node.values():
                us_, vs_ = node_pairs_among(nodes, max_cap=5)
                us4.extend(us_)
                vs4.extend(vs_)
        # Centre distance of every "near" edge, divided by the constant 24
        # (meaning of the constant is not evident from this file).
        dis4 = [[self.distance_among(u, v) / 24] for u, v in zip(us4, vs4)]

        print('\thetero_graph generated 1/2')

        # ---- heterogeneous graph, part 2: component <-> net edges ----
        us = []
        vs = []
        he = []
        net_span_feat = []

        for net, list_node_feats in self.place_net_dict.items():
            xs, ys, pxs, pys = [], [], [], []
            for node_feats in list_node_feats:
                px, py = node_pos[node_feats[0], :]
                if not px and not py:
                    continue
                he.append(node_feats[1:])
                vs.append(net)
                us.append(node_feats[0])
                cell_x_right_gcell = self.xsdata[node_feats[0]]
                cell_y_upper_gcell = self.ysdata[node_feats[0]]
                xs.append(int(cell_x_right_gcell))
                ys.append(int(cell_y_upper_gcell))
                pxs.append(px)
                pys.append(py)
            if len(xs) == 0:
                span_v = span_h = 0
                span_pv = span_ph = 0
            else:
                min_x, max_x, min_y, max_y = min(xs), max(xs), min(ys), max(ys)
                span_h = max_x - min_x + 1
                span_v = max_y - min_y + 1
                min_px, max_px, min_py, max_py = min(pxs), max(pxs), min(pys), max(pys)
                span_ph = max_px - min_px + 1
                span_pv = max_py - min_py + 1
            net_span_feat.append([span_v,
                                  span_h,
                                  span_v * span_h,
                                  span_pv, span_ph,
                                  span_pv * span_ph,
                                  len(list_node_feats)])

        net_hv = torch.tensor(net_span_feat, dtype=torch.float32)  # [net_nums, 7]
        net_label = [0] * net_hv.shape[0]
        net_degree = net_hv[:, -1]
        net_label = torch.unsqueeze(torch.tensor(net_label, dtype=torch.float32), dim=-1)
        hetero_graph = dgl.heterograph({
            ('node', 'near', 'node'): (us4, vs4),
            ('node', 'pins', 'net'): (us, vs),
            ('net', 'pinned', 'node'): (vs, us),
        }, num_nodes_dict={'node': n_node, 'net': net_hv.shape[0]})
        hetero_graph.nodes['node'].data['hv'] = homo_graph.ndata['feat']
        # The position code is stored separately rather than added to 'hv'.
        hetero_graph.nodes['node'].data['pos_code'] = torch.tensor(node_pos_code, dtype=torch.float32)
        hetero_graph.nodes['net'].data['hv'] = net_hv
        hetero_graph.nodes['net'].data['degree'] = net_degree  # 1-D, one value per net
        hetero_graph.nodes['net'].data['label'] = net_label  # zeros, shape (num_nets, 1)
        hetero_graph.edges['pinned'].data['he'] = torch.tensor(he, dtype=torch.float32)
        hetero_graph.edges['near'].data['he'] = torch.tensor(dis4, dtype=torch.float32)

        # Cut the heterogeneous graph along the same partitions as the
        # homogeneous one.  The edge list does not change between partitions,
        # so it is fetched once.
        pinned_net_ids, pinned_node_ids = [ns.tolist() for ns in hetero_graph.edges(etype='pinned')]
        list_hetero_graph = []
        iter_partition_list = tqdm.tqdm(partition_list, total=len(partition_list)) if self.use_tqdm else partition_list
        for partition in iter_partition_list:
            partition_set = set(partition)
            new_net_degree_dict = {}
            for net_id, node_id in zip(pinned_net_ids, pinned_node_ids):
                if node_id in partition_set:
                    new_net_degree_dict.setdefault(net_id, 0)
                    new_net_degree_dict[net_id] += 1
            keep_nets_id = np.array(list(new_net_degree_dict.keys()))
            keep_nets_degree = np.array(list(new_net_degree_dict.values()))
            # Keep only nets whose in-partition pin count equals their degree.
            good_nets = np.abs(net_degree[keep_nets_id] - keep_nets_degree) < 1e-5
            keep_nets_id = keep_nets_id[good_nets]
            part_hetero_graph = dgl.node_subgraph(hetero_graph, nodes={'node': partition, 'net': keep_nets_id})
            list_hetero_graph.append(part_hetero_graph)
        print('\thetero_graph generated 2/2')

        # Persist the result.
        list_tuple_graph = list(zip(list_homo_graph, list_hetero_graph))
        with open(self.graph_pickle, 'wb+') as fp:
            pickle.dump(list_tuple_graph, fp)
        return list_tuple_graph



    def get_circuit_features(self, args):
        """Set up the per-task data directories, then load features and labels.

        Fills in the directory lists for the congestion, DRC, IR-drop and
        thermal tasks (all rooted at ``self.data_root``), stores
        ``args.feature`` / ``args.label`` on the instance, delegates the
        actual loading to the ``get_*`` methods and finally records the size
        of dimension 1 of ``self.features`` and ``self.labels``.

        Args:
            args: object exposing ``feature`` and ``label``; each must be one
                of "congestion", "DRC", "IR_drop", "thermal" or "all".

        Raises:
            AssertionError: if ``args.feature`` or ``args.label`` is not one
                of the accepted task names.
        """
        root = self.data_root
        task_names = ("congestion", "DRC", "IR_drop", "thermal", "all")

        # Congestion task: inputs and labels share one sub-directory.
        self.congestion_root = 'routability_features'
        self.congestion_feature_list = [
            os.path.join(root, self.congestion_root, leaf)
            for leaf in ('macro_region', 'RUDY/RUDY', 'RUDY/RUDY_pin')
        ]
        self.congestion_label_list = [
            os.path.join(root, self.congestion_root, leaf)
            for leaf in (
                'congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow',
                'congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow',
            )
        ]

        # DRC task: nine input directories and a single label directory.
        drc_inputs = (
            'routability_features/macro_region',
            'routability_features/cell_density',
            'routability_features/RUDY/RUDY_long',
            'routability_features/RUDY/RUDY_short',
            'routability_features/RUDY/RUDY_pin_long',
            'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_horizontal_overflow',
            'routability_features/congestion/congestion_early_global_routing/overflow_based/congestion_eGR_vertical_overflow',
            'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_horizontal_overflow',
            'routability_features/congestion/congestion_global_routing/overflow_based/congestion_GR_vertical_overflow',
        )
        self.drc_feature_list = [os.path.join(root, rel) for rel in drc_inputs]
        self.drc_label_list = [os.path.join(root, 'routability_features/DRC/DRC_all')]

        # IR-drop task: five power maps as input, one IR-drop map as label.
        ir_inputs = (
            'IR_drop_features_decompressed/power_i',
            'IR_drop_features_decompressed/power_s',
            'IR_drop_features_decompressed/power_sca',
            'IR_drop_features_decompressed/power_all',
            'IR_drop_features_decompressed/power_t',
        )
        self.ir_feature_list = [os.path.join(root, rel) for rel in ir_inputs]
        self.ir_label_list = [os.path.join(root, 'IR_drop_features_decompressed/IR_drop')]

        # Thermal task: only a label directory is defined.
        self.thermal_label_list = [os.path.join(root, "hotmap")]

        self.feature = args.feature
        assert self.feature in task_names
        self.label = args.label
        assert self.label in task_names

        # "thermal" has no dedicated inputs, so it takes the same path as "all".
        load_features = (self.get_all_features
                         if self.feature in ("all", "thermal")
                         else self.get_features_one_task)
        load_features()

        load_labels = (self.get_all_labels
                       if self.label == "all"
                       else self.get_labels_one_task)
        load_labels()

        self.feature_dim = self.features.shape[1]
        self.label_dim = self.labels.shape[1]

    def get_all_features(self):
        """Load the congestion, DRC and IR-drop inputs of every sample.

        For each entry of ``self.name_list`` the maps are collected in this
        order: congestion directories, DRC directories, IR-drop directories.
        The channels of one sample are stacked into a tensor, and the
        per-sample tensors are stacked again into ``self.features``.
        """
        per_sample = []
        for sample in self.name_list:
            channels = []

            # Congestion and DRC maps go through the same std(resize(...)) step.
            for group in (self.congestion_feature_list, self.drc_feature_list):
                channels.extend(
                    torch.tensor(std(resize(np.load(os.path.join(directory, sample)))))
                    for directory in group
                )

            # IR-drop maps use resize_cv2 instead of resize.
            for directory in self.ir_feature_list:
                raw = np.load(os.path.join(directory, sample))
                if not directory.endswith('power_t'):
                    channels.append(torch.tensor(std(resize_cv2(raw.squeeze()))))
                    continue
                # power_t: the first 20 slices along axis 0 each become a
                # separate channel.
                channels.extend(
                    torch.tensor(std(resize_cv2(raw[step, :, :])))
                    for step in range(20)
                )

            per_sample.append(torch.stack(channels))
        self.features = torch.stack(per_sample)



    def get_features_one_task(self):
        """Load the input maps of the single task named by ``self.feature``.

        Handles "congestion", "DRC" and "IR_drop". For every entry of
        ``self.name_list`` the task's maps are stacked into one tensor; the
        per-sample tensors are then stacked into ``self.features``. For any
        other value of ``self.feature`` nothing is loaded and the final
        ``torch.stack`` receives an empty list.
        """
        def routability_stack(directories, sample):
            # One std(resize(...)) channel per directory.
            return torch.stack([
                torch.tensor(std(resize(np.load(os.path.join(directory, sample)))))
                for directory in directories
            ])

        def ir_stack(directories, sample):
            # One std(resize_cv2(...)) channel per directory, except power_t,
            # which contributes its first 20 slices along axis 0.
            channels = []
            for directory in directories:
                raw = np.load(os.path.join(directory, sample))
                if directory.endswith('power_t'):
                    channels.extend(
                        torch.tensor(std(resize_cv2(raw[step, :, :])))
                        for step in range(20)
                    )
                else:
                    channels.append(torch.tensor(std(resize_cv2(raw.squeeze()))))
            return torch.stack(channels)

        task = self.feature
        per_sample = []
        if task == "congestion":
            per_sample = [routability_stack(self.congestion_feature_list, sample)
                          for sample in self.name_list]
        elif task == 'DRC':
            per_sample = [routability_stack(self.drc_feature_list, sample)
                          for sample in self.name_list]
        elif task == 'IR_drop':
            per_sample = [ir_stack(self.ir_feature_list, sample)
                          for sample in self.name_list]
        self.features = torch.stack(per_sample)

    def get_all_labels(self):
        """Load the congestion, DRC and IR-drop labels of every sample.

        For each entry of ``self.name_list`` the label maps are collected in
        this order: congestion, DRC, IR-drop. They are stacked into one
        tensor per sample, and the per-sample tensors are stacked into
        ``self.labels``. Thermal labels are not included here.
        """
        # Denominator of the IR-drop log scaling; it does not depend on the data.
        ir_scale = np.log10(50) + 6

        per_sample = []
        for sample in self.name_list:
            # Congestion: resized only, no value scaling.
            targets = [
                torch.tensor(resize(np.load(os.path.join(directory, sample))))
                for directory in self.congestion_label_list
            ]

            # DRC: clip to [0, 200], resize, then divide by 200.
            for directory in self.drc_label_list:
                clipped = np.clip(np.load(os.path.join(directory, sample)), 0, 200)
                targets.append(torch.tensor(resize_cv2(clipped) / 200))

            # IR drop: squeeze, clip to [1e-6, 50], resize, then apply
            # (log10(x) + 6) / (log10(50) + 6).
            for directory in self.ir_label_list:
                squeezed = np.squeeze(np.load(os.path.join(directory, sample)))
                clipped = np.clip(squeezed, 1e-6, 50)
                targets.append(
                    torch.tensor((np.log10(resize_cv2(clipped)) + 6) / ir_scale)
                )

            per_sample.append(torch.stack(targets))
        self.labels = torch.stack(per_sample)


    def get_labels_one_task(self):
        """Load the label maps of the single task named by ``self.label``.

        Handles "congestion", "DRC", "IR_drop" and "thermal". For every entry
        of ``self.name_list`` the task's label maps are stacked into one
        tensor; the per-sample tensors are then stacked into ``self.labels``.
        For any other value of ``self.label`` nothing is loaded and the final
        ``torch.stack`` receives an empty list.
        """
        def resized_target(path):
            # Used for congestion and thermal: resize only.
            return torch.tensor(resize(np.load(path)))

        def drc_target(path):
            # Clip to [0, 200], resize, then divide by 200.
            clipped = np.clip(np.load(path), 0, 200)
            return torch.tensor(resize_cv2(clipped) / 200)

        def ir_target(path):
            # Squeeze, clip to [1e-6, 50], resize, then log-scale.
            clipped = np.clip(np.squeeze(np.load(path)), 1e-6, 50)
            return torch.tensor(
                (np.log10(resize_cv2(clipped)) + 6) / (np.log10(50) + 6)
            )

        task = self.label
        per_sample = []
        if task == "congestion":
            per_sample = [
                torch.stack([resized_target(os.path.join(directory, sample))
                             for directory in self.congestion_label_list])
                for sample in self.name_list
            ]
        elif task == 'DRC':
            per_sample = [
                torch.stack([drc_target(os.path.join(directory, sample))
                             for directory in self.drc_label_list])
                for sample in self.name_list
            ]
        elif task == 'IR_drop':
            per_sample = [
                torch.stack([ir_target(os.path.join(directory, sample))
                             for directory in self.ir_label_list])
                for sample in self.name_list
            ]
        elif task == "thermal":
            # Thermal files are looked up with ".npy" appended to the sample
            # name, unlike the other tasks.
            per_sample = [
                torch.stack([resized_target(os.path.join(directory, sample + ".npy"))
                             for directory in self.thermal_label_list])
                for sample in self.name_list
            ]

        self.labels = torch.stack(per_sample)


    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        feature_map = self.features[idx]
        labels = self.labels[idx]
        return feature_map, labels








