import torch
import numpy as np
import pickle, json, time, re, sys
import networkx as nx
from multiprocessing import Pool
import dgl
from dgl import from_networkx
import dgl

from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import os

class TextEmbedDataset(torch.utils.data.Dataset):
    """In-memory dataset of pre-computed text-embedding samples.

    Each sample is a dict of tensors. ``collate`` reads its ``'outputs'`` and
    ``'embeddings'`` keys; ``collate_bak`` additionally reads ``'inputs'`` and
    ``'inputs_att_mask'``. Labels are stored in parallel in ``self.label`` but
    are not returned by ``__getitem__``.
    """

    def __init__(self):
        super().__init__()
        self.embed = []  # per-sample dicts of tensors
        self.label = []  # per-sample label objects, parallel to self.embed

    def add_text_embed_data(self, embed_dict, label_data):
        """Append one sample dict and its label."""
        self.embed.append(embed_dict)
        self.label.append(label_data)

    def __getitem__(self, idx):
        """Return the sample dict at *idx* (the label is not included)."""
        return self.embed[idx]

    def __len__(self):
        return len(self.embed)

    @staticmethod
    def _pad_and_stack_outputs(outputs):
        """Zero-pad tensors to the batch max of dims 1 and 2, stack, squeeze dim 1.

        The 4-tuple given to ``F.pad`` pads the LAST two dims, so this matches
        "dims 1 and 2" only for 3-D tensors; presumably each is
        (1, seq_len, hidden) -- TODO confirm. ``squeeze(1)`` then drops that
        leading singleton after stacking (it is a no-op if the dim is not 1).
        """
        max_dim1 = max(out.size(1) for out in outputs)
        max_dim2 = max(out.size(2) for out in outputs)
        padded = [
            # F.pad order: (last_left, last_right, 2nd_last_left, 2nd_last_right)
            F.pad(out, (0, max_dim2 - out.size(2), 0, max_dim1 - out.size(1)), "constant", 0)
            for out in outputs
        ]
        return torch.stack(padded).squeeze(1)

    @staticmethod
    def _pad_dim1_and_stack(tensors):
        """Right-pad the last dim with zeros up to the batch max of dim 1, then stack.

        Correct as "pad dim 1" only for 2-D tensors (last dim == dim 1).
        """
        max_len = max(t.size(1) for t in tensors)
        return torch.stack([F.pad(t, (0, max_len - t.size(1)), "constant", 0) for t in tensors])

    def collate(self, samples):
        """Collate sample dicts into an ``(outputs, embeddings)`` batch.

        ``outputs`` are zero-padded to a common size (see
        ``_pad_and_stack_outputs``); ``embeddings`` are stacked as-is, so they
        must already share one shape. Prints both batch shapes.
        """
        samples = list(samples)
        outputs = self._pad_and_stack_outputs([s['outputs'] for s in samples])
        embeddings = torch.stack([s['embeddings'] for s in samples])

        print(outputs.shape, embeddings.shape)

        return (outputs, embeddings)

    def collate_bak(self, samples):
        """Legacy collate that also pads ``'inputs'`` and ``'inputs_att_mask'``.

        Only the padded ``outputs`` batch is returned; the padded inputs, masks
        and stacked embeddings are built solely for the shape printout.
        """
        samples = list(samples)
        inputs = self._pad_dim1_and_stack([s['inputs'] for s in samples])
        inputs_att_mask = self._pad_dim1_and_stack([s['inputs_att_mask'] for s in samples])
        outputs = self._pad_and_stack_outputs([s['outputs'] for s in samples])
        embeddings = torch.stack([s['embeddings'] for s in samples])

        print(inputs.shape, inputs_att_mask.shape, outputs.shape, embeddings.shape)

        return outputs

def run_one_ep(design_name, ep):
    """Load one endpoint's embedding pickle and append it to the global ``dataset``.

    Reads ``./llm_enc/rtl_emb/{cmd}/{design_name}/{ep}.pkl`` (``cmd`` is a
    module global). The pickle must hold a 3-tuple
    ``(batch_dict, outputs, embeddings)``, where ``batch_dict`` has
    ``'input_ids'`` and ``'attention_mask'`` tensors. All tensors are moved to
    CPU and their shapes printed. Only ``outputs`` and ``embeddings`` are kept;
    they are appended to the module-global ``dataset`` with an empty label.

    NOTE: ``pickle.load`` can execute arbitrary code -- only use on trusted files.
    """
    pkl_path = f'./llm_enc/rtl_emb/{cmd}/{design_name}/{ep}.pkl'
    with open(pkl_path, 'rb') as fh:
        loaded = pickle.load(fh)

    batch_dict, raw_outputs, raw_embeddings = loaded
    # The token ids / attention mask are loaded only for the shape printout.
    input_ids = batch_dict['input_ids'].cpu()
    attention_mask = batch_dict['attention_mask'].cpu()
    record = {
        'outputs': raw_outputs.cpu(),
        'embeddings': raw_embeddings.cpu(),
    }

    print(input_ids.shape, attention_mask.shape,
          record['outputs'].shape, record['embeddings'].shape)

    dataset.add_text_embed_data(record, [])


def save_dataset(design_lst, tag, *,
                 label_root="/home/coguest5/hdl_fusion/data_collect/label/ep_lst",
                 func_root="/home/coguest5/hdl_fusion/text_enc/llm_extra",
                 out_dir='./data_bench'):
    """Build a TextEmbedDataset for *design_lst* and pickle it to disk.

    For each design, the endpoint list is read from
    ``{label_root}/{design}.json``. An endpoint is embedded via ``run_one_ep``
    only when both ``{func_root}/rtl_func_ori/{design}/{ep}.txt`` and
    ``{func_root}/rtl_func_pos/{design}/{ep}.txt`` exist. Otherwise it is
    reported and skipped.

    Args:
        design_lst: Iterable of design names.
        tag: Split name used in the output file name (e.g. ``'train'``).
        label_root: Directory of per-design endpoint-list JSON files.
        func_root: Directory containing the ``rtl_func_ori`` / ``rtl_func_pos``
            text trees.
        out_dir: Output directory; created if missing.

    Side effects:
        Rebinds the module globals ``dataset`` (appended to by ``run_one_ep``)
        and ``data_dict``, reads the module global ``cmd``, and writes
        ``{out_dir}/dataset_{tag}_{cmd}.pkl``.
    """
    global dataset, data_dict
    data_dict = {}

    print(f'Tag: {tag} CMD: {cmd}')
    dataset = TextEmbedDataset()
    for design in design_lst:
        with open(f"{label_root}/{design}.json", 'r') as f:
            reg_lst = json.load(f)

        for ep in reg_lst:
            # Both functional-description variants must exist. all() short-circuits
            # (ori checked first), so at most one message is printed per endpoint.
            has_text = all(
                os.path.exists(f"{func_root}/{variant}/{design}/{ep}.txt")
                for variant in ("rtl_func_ori", "rtl_func_pos")
            )
            if not has_text:
                print("Not exist: ", design, ep)
                continue
            run_one_ep(design, ep)

    print(len(dataset))
    # Create the output directory on first use instead of failing with FileNotFoundError.
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/dataset_{tag}_{cmd}.pkl', 'wb') as f:
        pickle.dump(dataset, f)

    ### load dataset
    # text_loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=dataset.collate)
    # for data in text_loader:
    #     print(data)
    #     break




if __name__ == '__main__':
    # `cmd` is read as a module global by save_dataset() and run_one_ep(): it
    # selects the ./llm_enc/rtl_emb/<cmd>/ input directory and the output-file
    # suffix. Plain assignment at module scope already binds the global, so no
    # `global` statement is needed here.
    cmd = "ori"

    def _load_json(path):
        """Return the decoded contents of the JSON file at *path*."""
        with open(path, 'r') as f:
            return json.load(f)

    # All split lists are loaded before any dataset is built, so a missing
    # list file aborts the run before anything is written.
    train_design_lst = _load_json("../dataset_js/train_lst.json")
    valid_design_lst = _load_json("../dataset_js/valid_lst.json")
    demo_lst = _load_json("../dataset_js/demo_lst.json")

    # Other splits, enable as needed:
    #   design_lst         = _load_json("../../dataset_js/design_lst.json")
    #   test_design_lst    = _load_json("../../dataset_js/test_lst.json");   save_dataset(test_design_lst, 'test')
    #   sft_design_lst_all = _load_json("../dataset_js/sft_none_lst.json");  save_dataset(sft_design_lst_all, 'sft')
    #   sft_design_lst_all = _load_json("../dataset_js/sft_ft_lst.json");    save_dataset(sft_design_lst_all, 'sft')
    #   sft_design_lst_all = _load_json("../dataset_js/sft_ft_lst_4.json");  save_dataset(sft_design_lst_all, 'sft_4')
    #   sft_design_lst_all = _load_json("../dataset_js/sft_ft_lst_16.json"); save_dataset(sft_design_lst_all, 'sft_16')

    save_dataset(train_design_lst, 'train')
    save_dataset(valid_design_lst, 'valid')
    save_dataset(demo_lst, 'demo')