import torch
import numpy as np
import pickle, json, time, re, sys, os
import networkx as nx
from multiprocessing import Pool
import dgl
from dgl import from_networkx
import dgl
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer


def run_one_design(design_name, folder_dir='/home/coguest5/CircuitFusion/data_collect/vlg/data/ori_vlg'):
    """Return the Verilog source of a design as one newline-free string.

    Reads ``{folder_dir}/{design_name}.v`` and removes every ``'\\n'`` so that
    all source lines are concatenated back-to-back. Nothing is inserted between
    them, and other whitespace is kept as-is.

    Args:
        design_name: Design name, used as the file stem (``<name>.v``).
        folder_dir: Directory holding the ``.v`` files. Defaults to the
            original hard-coded location.

    Returns:
        The file contents with all newline characters removed. This is ``''``
        for an empty file.
    """
    with open(f'{folder_dir}/{design_name}.v', 'r') as f:
        # Text mode already normalises '\r\n' / '\r' to '\n', exactly as the
        # previous readlines() loop saw it. One replace over the whole text
        # avoids the quadratic cost of `+=` string building.
        return f.read().replace('\n', '')


def get_dataset(design_lst, stats=None,
                graph_dir='/home/coguest5/CircuitFusion/data_collect/dataset/rtl_graph/ori'):
    """Collect the node count of each design's pickled RTL graph.

    For every design name, loads ``{graph_dir}/{design}_ori.pkl`` with
    ``pickle``, wraps the result in ``nx.DiGraph``, and appends its number of
    nodes to *stats*.

    Args:
        design_lst: Iterable of design names.
        stats: List to append node counts to. When ``None``, the
            module-level ``stat_lst`` is used, as in the original contract
            relied on by the ``__main__`` block. It is created if it does not
            exist yet, so importing and calling this function no longer
            raises ``NameError``.
        graph_dir: Directory containing the ``*_ori.pkl`` graph files.
            Defaults to the original hard-coded location.

    Returns:
        The list the node counts were appended to.
    """
    if stats is None:
        stats = globals().setdefault('stat_lst', [])

    # The previous version loaded an e5-mistral tokenizer here that was never
    # used. It was removed to avoid a costly, pointless model download/load.
    for design in design_lst:
        print("Current design: ", design)

        # NOTE: pickle can execute arbitrary code on load; only point this at
        # trusted, locally generated dataset files.
        with open(f"{graph_dir}/{design}_ori.pkl", 'rb') as f:
            graph = pickle.load(f)
        graph = nx.DiGraph(graph)
        stats.append(len(graph.nodes))

    return stats


if __name__ == '__main__':
    # Benchmark suites with dataset JSONs: 'itc', 'opencores', 'vex',
    # 'chipyard'. Pick one on the command line. The default is 'chipyard',
    # which was the last (and therefore the only effective) assignment before.
    bench = sys.argv[1] if len(sys.argv) > 1 else 'chipyard'

    # Module-level list that get_dataset() appends per-design node counts to.
    stat_lst = []
    with open(f"/home/coguest5/CircuitFusion/dataset/dataset_js/design_{bench}.json", 'r') as f:
        design_lst_all = json.load(f)

    get_dataset(design_lst_all)
    print(f'Benchmark: {bench}')
    if not stat_lst:
        # np.max / np.min raise ValueError on an empty sequence.
        print("No designs processed; no statistics to report.")
    else:
        # stat_lst holds graph node counts (len(graph.nodes)), not token counts.
        print(f"Average number of nodes: {np.mean(stat_lst)}")
        print(f"Max number of nodes: {np.max(stat_lst)}")
        print(f"Min number of nodes: {np.min(stat_lst)}")
        print(f'Median number of nodes: {np.median(stat_lst)}')