import torch
import numpy as np
import pickle, json, time, re, sys
import networkx as nx
from multiprocessing import Pool
import dgl
from dgl import from_networkx
import dgl
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import os
from dgl.data import DGLDataset
from DG import Graph, Node
import copy

class RegGraphDataset(dgl.data.DGLDataset):
    """In-memory DGL dataset of ``(graph, label)`` pairs.

    This class does not override ``process``; the caller populates it one
    sample at a time through :meth:`add_graph_data`.
    """

    def __init__(self):
        super().__init__(name='my_dataset')
        # Parallel lists: graphs[i] is paired with label[i].
        self.graphs, self.label = [], []

    def add_graph_data(self, dgl_graph, label_data):
        """Append one graph together with its label payload."""
        self.graphs.append(dgl_graph)
        self.label.append(label_data)

    def __getitem__(self, idx):
        """Return the ``(graph, label)`` pair stored at position *idx*."""
        graph = self.graphs[idx]
        label = self.label[idx]
        return graph, label

    def __len__(self):
        """Number of stored graphs."""
        return len(self.graphs)


def _has_rtl_func_texts(root, design, ep):
    """Return True if both RTL-function text files for endpoint *ep* exist.

    Checks the ``rtl_func_ori`` and ``rtl_func_pos`` variants under
    ``{root}/text_enc/llm_extra/<variant>/{design}/{ep}.txt``.
    """
    return all(
        os.path.exists(f"{root}/text_enc/llm_extra/{variant}/{design}/{ep}.txt")
        for variant in ("rtl_func_ori", "rtl_func_pos")
    )


def run_one_design(design, *, root='/home/coguest5/hdl_fusion'):
    """Build the endpoint-graph dataset for one design and pickle it to disk.

    Loads the design's dependency graph and keeps only endpoints that are
    both listed in the design's endpoint list and backed by both
    RTL-function text files. Each kept node's cone ``slack`` becomes the
    node label ``label_slack``. The result is written as a single-graph
    :class:`RegGraphDataset` whose label is ``(node_lst, area, pwr, wns, tns)``.

    Args:
        design: Design name used to build every input/output path.
        root: Project root that all data paths are resolved against
            (keyword-only; defaults to the previously hard-coded location).

    Raises:
        OSError (e.g. FileNotFoundError): an input file is missing.
        KeyError: a label JSON lacks an expected key.
    """
    graph_dir = f'{root}/data_collect/depend_graph/data'
    # NOTE: pickle.load can execute arbitrary code; these are assumed to be
    # project-generated files, not untrusted input.
    with open(f'{graph_dir}/{design}_ep_graph.pkl', 'rb') as f:
        graph = pickle.load(f)

    with open(f"{root}/data_collect/label/ep_lst/{design}.json", 'r') as f:
        reg_lst = json.load(f)
    # Drop endpoints lacking either text description. Building the set
    # directly (instead of set.remove) also tolerates duplicate entries.
    reg_set = {ep for ep in reg_lst if _has_rtl_func_texts(root, design, ep)}

    g_nx = nx.DiGraph(graph)
    g_nx.remove_nodes_from([node for node in g_nx if node not in reg_set])

    # node_lst[i] names the i-th kept node (graph insertion order) and
    # slack_lst[i] is its slack label.
    node_lst = list(g_nx.nodes())
    slack_lst = []
    for node in node_lst:
        with open(f"{root}/data_collect/label/ppa/cone_pwr_area/{design}/{node}.json", 'r') as f:
            slack_lst.append(json.load(f)['slack'])

    # Relabel nodes to 0..n-1 in node_lst order before handing the graph to
    # DGL. from_networkx assigns its own integer ids (it may reorder nodes,
    # e.g. by sorted label); without this step the per-node slack labels
    # could be attached to the wrong nodes.
    g_int = nx.convert_node_labels_to_integers(g_nx, ordering='default')
    dgl_graph = from_networkx(g_int)
    dgl_graph.ndata['label_slack'] = torch.tensor(np.array(slack_lst))

    # Graph-level labels.
    with open(f"{root}/data_collect/label/ppa/json/{design}/ppa.json", 'r') as f:
        ppa_dct = json.load(f)

    dataset = RegGraphDataset()
    dataset.add_graph_data(
        dgl_graph,
        (node_lst, ppa_dct['area'], ppa_dct['pwr'], ppa_dct['wns'], ppa_dct['tns']),
    )

    out_dir = f'{root}/dataset/dataset_finetune/data_bench'
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/{design}.pkl', 'wb') as f:
        pickle.dump(dataset, f)

def save_dataset_one_design(design_lst):
    """Run ``run_one_design`` for each design name in *design_lst*, in order.

    Despite the singular name, this processes a whole list of designs and
    prints a progress line before each one.
    """
    for name in design_lst:
        print('Current Design: ', name)
        run_one_design(name)



if __name__ == '__main__':
    # Module-level setting; not read by any code visible in this script.
    # (The former `global cmd` statement was a no-op at module scope.)
    cmd = "ori"

    # The design list path is resolved relative to the current working directory.
    design_list_file = "../dataset_js/design_all.json"
    with open(design_list_file, 'r') as handle:
        design_lst = json.load(handle)

    save_dataset_one_design(design_lst)
