import torch, os
import numpy as np
import pickle, json, time, re, sys
from DG import Graph, Node
import networkx as nx
from multiprocessing import Pool
import dgl
from dgl import from_networkx
import dgl

from graph2dgl_net import node_feat_extra_net

class NetDataset(dgl.data.DGLDataset):
    """In-memory DGL dataset that is filled incrementally.

    Graphs and their label data are kept in two parallel lists that grow
    together through :meth:`add_graph_data`. Indexing returns only the
    graph; the label data stays reachable through the ``label`` attribute.
    """

    def __init__(self):
        super().__init__(name='my_dataset')
        # Parallel containers: label[i] is the label data stored with graphs[i].
        self.graphs, self.label = [], []

    def add_graph_data(self, dgl_graph, label_data):
        """Append one graph together with its label data."""
        self.graphs += [dgl_graph]
        self.label += [label_data]

    def __getitem__(self, idx):
        """Return the graph stored at position ``idx``."""
        graph = self.graphs[idx]
        return graph

    def __len__(self):
        """Return the number of stored graphs."""
        stored = self.graphs
        return len(stored)



def run_one_ep(design_name, ep, node_dict):
    """Convert one pickled cone graph to a DGL graph and add it to ``dataset``.

    Loads ``{folder_dir}/{design_name}/{ep}.pkl``, builds one feature row and
    one label row per node, attaches them as ``ndata['feat']`` and
    ``ndata['label']`` and appends the graph to the module-level ``dataset``.

    Relies on the module globals ``folder_dir`` and ``dataset``, which
    ``save_dataset`` sets before calling this function.

    Args:
        design_name: Sub-directory of ``folder_dir`` holding the graph pickle.
        ep: Base name (without ``.pkl``) of the graph pickle.
        node_dict: Mapping from node name to the node object handed to
            ``node_feat_extra_net``. Nodes missing from it get an all-zero
            feature row and an all-zero label row.
    """
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(f'{folder_dir}/{design_name}/{ep}.pkl', 'rb') as f:
        g_nx = pickle.load(f)

    # Width of the placeholder feature row for nodes absent from node_dict.
    # Presumably matches the width returned by node_feat_extra_net -- confirm.
    missing_feat_dim = 28

    node_names = list(g_nx.nodes())
    feat_rows = []
    label_rows = []
    unlabeled_rows = []
    for row, node_name in enumerate(node_names):
        if node_name in node_dict:
            feat_vec, label_vec = node_feat_extra_net(
                node_name, node_dict[node_name], g_nx, node_dict
            )
        else:
            # Previously no label was assigned here, so the append below
            # either raised UnboundLocalError or reused the previous node's
            # label. The row is remembered and filled with zeros afterwards.
            feat_vec = np.zeros(missing_feat_dim)
            label_vec = None
            unlabeled_rows.append(row)
        feat_rows.append(feat_vec)
        label_rows.append(label_vec)

    if unlabeled_rows:
        # Shape the zero label after any real label so the rows stack; fall
        # back to a scalar zero when the graph has no known node at all.
        template = next((lab for lab in label_rows if lab is not None), 0.0)
        placeholder = np.zeros_like(np.asarray(template, dtype=float))
        for row in unlabeled_rows:
            label_rows[row] = placeholder

    feat_matrix = torch.tensor(np.array(feat_rows), dtype=torch.float32)
    label_matrix = torch.tensor(np.array(label_rows), dtype=torch.float32)

    # from_networkx relabels nodes to consecutive integers itself. Relabelling
    # explicitly in the same order as the rows built above guarantees that
    # row i of feat/label belongs to DGL node i, whatever order DGL would pick.
    g_indexed = nx.convert_node_labels_to_integers(g_nx, ordering='default')
    dgl_graph = from_networkx(g_indexed)
    dgl_graph.ndata['feat'] = feat_matrix
    dgl_graph.ndata['label'] = label_matrix

    dataset.add_graph_data(dgl_graph, label_matrix)

def _rtl_text_available(design, ep):
    """Return True when both the 'ori' and 'pos' RTL text files exist."""
    return all(
        os.path.exists(
            f"/home/coguest5/hdl_fusion/text_enc/llm_extra/{variant}/{design}/{ep}.txt"
        )
        for variant in ('rtl_func_ori', 'rtl_func_pos')
    )


def _load_ep_list(design):
    """Load the list of endpoint names recorded for ``design``."""
    with open(f"/home/coguest5/hdl_fusion/data_collect/label/ep_lst/{design}.json", 'r') as f:
        return json.load(f)


def _load_node_dict(design):
    """Load the pickled node dictionary of ``design`` from ``folder_dir``."""
    # NOTE: pickle.load executes arbitrary code; only load trusted files.
    with open(f'{folder_dir}/{design}/{design}_node_dict.pkl', 'rb') as f:
        return pickle.load(f)


def save_dataset(design_lst, tag):
    """Build a ``NetDataset`` for ``design_lst`` and pickle it to ``./data_bench``.

    The module-level ``cmd`` selects what is exported:

    * ``'ori'`` / ``'pos'``: every endpoint of every design, read from the
      ``ori`` / ``pos1`` cone-graph folder.
    * ``'neg'``: for every endpoint, the graph named by its ``'neg1'`` entry
      in ``reg_map.json``, read from the ``ori`` cone-graph folder.

    Endpoints lacking either RTL text file are reported and skipped. The
    result is written to ``./data_bench/dataset_{tag}_{cmd}.pkl``.

    Sets the module globals ``dataset`` and ``folder_dir`` (both read by
    ``run_one_ep``) and reads the globals ``cmd`` and, for ``'neg'``,
    ``design_lst_all``.

    Args:
        design_lst: Design names to export.
        tag: Label used in the output file name (e.g. ``'train'``).

    Raises:
        ValueError: If ``cmd`` is not ``'ori'``, ``'pos'`` or ``'neg'``.
    """
    global dataset, folder_dir

    cone_graph_dirs = {
        'ori': '/home/coguest5/hdl_fusion/data_collect/dataset_net/cone_graph/ori/',
        'pos': '/home/coguest5/hdl_fusion/data_collect/dataset_net/cone_graph/pos1/',
        'neg': '/home/coguest5/hdl_fusion/data_collect/dataset_net/cone_graph/ori/',
    }
    # Fail before doing any work: an unknown cmd used to fall through and
    # silently write an empty dataset file.
    if cmd not in cone_graph_dirs:
        raise ValueError(
            f"Unsupported cmd {cmd!r}; expected one of {sorted(cone_graph_dirs)}"
        )

    folder_dir = cone_graph_dirs[cmd]
    print(f'Tag: {tag} CMD: {cmd}')
    dataset = NetDataset()

    if cmd == 'neg':
        with open("/home/coguest5/hdl_fusion/text_enc/bert_cl/neg_map/reg_map.json", 'r') as f:
            reg_map = json.load(f)

        # A negative sample may come from any design, so every node dict is
        # loaded up front.
        node_dict_dct = {
            design: _load_node_dict(design) for design in design_lst_all
        }

        idx = 0
        for design in design_lst:
            for ep in _load_ep_list(design):
                if not _rtl_text_available(design, ep):
                    print("Not exist: ", design, ep)
                    continue
                neg_mapped = reg_map[design][ep]['neg1']
                design_neg, ep_neg = neg_mapped[0], neg_mapped[1]
                run_one_ep(design_neg, ep_neg, node_dict_dct[design_neg])
                idx += 1
                print(idx)
    else:
        for design in design_lst:
            reg_lst = _load_ep_list(design)
            node_dict = _load_node_dict(design)
            for ep in reg_lst:
                if not _rtl_text_available(design, ep):
                    print("Not exist: ", design, ep)
                    continue
                run_one_ep(design, ep, node_dict)

    print(len(dataset))
    # Create the output folder so a missing directory cannot discard the
    # dataset that was just built.
    os.makedirs('./data_bench', exist_ok=True)
    with open(f'./data_bench/dataset_{tag}_{cmd}.pkl', 'wb') as f:
        pickle.dump(dataset, f)


if __name__ == '__main__':
    # Script entry point. ``cmd`` and ``design_lst_all`` are assigned at
    # module scope here, so they are module globals that save_dataset reads.

    # Export mode; switch by hand between the three variants.
    cmd = "ori"
    # cmd = "pos"
    # cmd = "neg"

    print("Current cmd: ", cmd)

    # Names of all designs.
    with open("../dataset_js/design_all.json", 'r') as design_all_file:
        design_lst_all = json.load(design_all_file)

    # To export the train / valid splits instead of the demo list:
    # with open("../dataset_js/train_lst.json", 'r') as f:
    #     train_design_lst = json.load(f)
    # with open("../dataset_js/valid_lst.json", 'r') as f:
    #     valid_design_lst = json.load(f)
    # save_dataset(train_design_lst, 'train')
    # save_dataset(valid_design_lst, 'valid')

    with open("../dataset_js/demo_lst.json", 'r') as demo_file:
        demo_design_lst = json.load(demo_file)
    save_dataset(demo_design_lst, 'demo')

