import torch
import numpy as np
import pickle, json, time, re, sys, os
import networkx as nx
from multiprocessing import Pool
import dgl
from dgl import from_networkx
import dgl
from torch import Tensor
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

# Expose only GPUs 0 and 1 to this process. This is set before the
# Accelerator below is constructed so that it only sees those devices.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
from accelerate import Accelerator

# Module-level Accelerator shared by get_dataset(), which uses it to prepare
# the model and to pick the device that tokenized inputs are moved to.
accelerator = Accelerator()


def last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Select, for each sequence in the batch, the hidden state of its last attended token.

    Args:
        last_hidden_states: hidden states indexed as [batch, position, ...].
        attention_mask: mask indexed as [batch, position]; assumed to hold 1
            for attended tokens and 0 for padding, so that a row sum equals
            the number of attended tokens.

    Returns:
        One hidden-state vector per batch row.
    """
    num_rows = attention_mask.shape[0]

    # When every row has its final mask entry set, the batch is treated as
    # left-padded and the last position is the answer for all rows.
    if attention_mask[:, -1].sum() == num_rows:
        return last_hidden_states[:, -1]

    # Otherwise index each row at (number of attended tokens - 1).
    last_positions = attention_mask.sum(dim=1) - 1
    row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
    return last_hidden_states[row_index, last_positions]


def get_detailed_instruct(task_description: str, query: str) -> str:
    """Build the two-line prompt: an ``Instruct:`` line followed by a ``Query:`` line."""
    template = 'Instruct: {}\nQuery: {}'
    return template.format(task_description, query)

def run_one_ep(design_name, ep):
    """Read one Verilog file and return its contents with all newlines removed.

    The file is ``<dataset root>/<cmd>/vlg/<design_name>/<ep>.v``, where
    ``cmd`` is a module-level global that must be set before this is called.

    Args:
        design_name: name of the design sub-directory.
        ep: file stem (without the ``.v`` suffix) inside that directory.

    Returns:
        The file's text as a single string; lines are joined directly, with
        no separator inserted where the newlines were.
    """
    vlg_root = f'/home/coguest5/hdl_fusion/data_collect/dataset/{cmd}/vlg'
    with open(f'{vlg_root}/{design_name}/{ep}.v', 'r') as src:
        text = src.read()
    return text.replace('\n', '')


def get_dataset(design_lst):
    """Embed the Verilog files of each design with e5-mistral-7b-instruct and pickle the results.

    For every design in ``design_lst``:
      * the list of entries is loaded from
        ``/home/coguest5/hdl_fusion/data_collect/label/ep_lst/<design>.json``;
      * ``../rtl_emb/<cmd>/<design>`` is wiped and recreated;
      * each entry's Verilog text (via ``run_one_ep``) is tokenized, run
        through the model, pooled with ``last_token_pool`` and written to
        ``<save_dir>/<ep>.pkl`` as the tuple
        ``(batch_dict, last_hidden_state, embeddings)``;
      * the elapsed wall-clock time is printed and written to
        ``../runtime/<design>.txt``.

    Relies on the module-level globals ``cmd`` and ``accelerator``.

    Args:
        design_lst: iterable of design names.

    Returns:
        None. All results are written to disk.
    """
    import shutil  # local import: only needed for the output-directory cleanup below

    max_length = 4096
    runtime_dir = "../runtime"

    # model = SentenceTransformer("intfloat/e5-mistral-7b-instruct")
    # model.max_seq_length = 4096
    tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-mistral-7b-instruct')
    model = AutoModel.from_pretrained('intfloat/e5-mistral-7b-instruct', torch_dtype=torch.float16)

    # Only the model needs preparing; tokenizers and plain strings are passed
    # through Accelerator.prepare unchanged, so they are no longer sent to it.
    model = accelerator.prepare(model)

    # Make sure the timing output directory exists up front, so a long
    # embedding run cannot fail at the very end on a missing directory.
    os.makedirs(runtime_dir, exist_ok=True)

    for design in design_lst:
        print("Current design: ", design)
        with open(f"/home/coguest5/hdl_fusion/data_collect/label/ep_lst/{design}.json", 'r') as f:
            reg_lst = json.load(f)

        # Start from an empty output directory. shutil.rmtree avoids building
        # a shell command from the design name.
        save_dir = f"../rtl_emb/{cmd}/{design}"
        if os.path.exists(save_dir):
            shutil.rmtree(save_dir)
        os.makedirs(save_dir)

        time_start = time.time()
        for ep in reg_lst:
            # One document per entry; no instruction prefix is added to it.
            documents = run_one_ep(design, ep)

            # Tokenize the input text (truncated to max_length tokens).
            batch_dict = tokenizer(
                documents,
                max_length=max_length,
                padding=True,
                truncation=True,
                return_tensors='pt',
            ).to(accelerator.device)

            # Inference only: skip autograd bookkeeping to avoid holding the
            # activation graph of a 7B model in memory.
            with torch.no_grad():
                outputs = model(**batch_dict)
                embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

            # NOTE(review): the pooled embedding is saved un-normalized, as
            # before (an L2-normalized copy used to be computed here but was
            # never stored). Confirm whether consumers expect normalization.
            with open(f"{save_dir}/{ep}.pkl", 'wb') as f:
                pickle.dump((batch_dict, outputs.last_hidden_state, embeddings), f)

        time_end = time.time()
        print("Time used: ", time_end-time_start)
        with open(f"{runtime_dir}/{design}.txt", 'w') as f:
            f.write(str(time_end-time_start))
    

if __name__ == '__main__':
    # `cmd` and `design_lst_all` are assigned at module scope here, so they
    # are already module globals (no `global` statement is needed). `cmd` is
    # read by run_one_ep() and get_dataset() to build their input and output
    # paths.
    # Dataset variant to process; the other value used by this script is "ori".
    cmd = "pos"

    with open("/home/coguest5/hdl_fusion/dataset_js/design_all.json", 'r') as f:
        design_lst_all = json.load(f)

    get_dataset(design_lst_all)