import torch
import numpy as np
import pickle, json, time, re, sys, os
import networkx as nx
from multiprocessing import Pool
import dgl
from dgl import from_networkx
import dgl
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer


def get_train_emb(
    design_lst,
    label_dir="/home/coguest5/CircuitFusion/data_collect/label/ep_lst",
    emb_root="../rtl_emb",
):
    """Load every available pickled embedding for the given designs.

    For each design, ``{label_dir}/{design}.json`` is read to obtain the list
    of ``ep`` names, and ``{emb_root}/{design}/{ep}.pkl`` is unpickled for each
    name. Names whose pickle file does not exist are skipped silently.

    Args:
        design_lst: Iterable of design names.
        label_dir: Directory holding one ``{design}.json`` list per design.
            Defaults to the location the script originally hard-coded.
        emb_root: Root directory holding ``{design}/{ep}.pkl`` files.
            Defaults to the location the script originally hard-coded.

    Returns:
        A tuple ``(emb_lst, design_ep_lst)``: ``emb_lst`` is ``np.array`` of
        the loaded embeddings in load order, and ``design_ep_lst`` is the
        parallel list of ``(design, ep)`` tuples, so row ``i`` of ``emb_lst``
        belongs to ``design_ep_lst[i]``.

    Raises:
        FileNotFoundError: If a design's JSON list is missing.
    """
    emb_lst = []
    design_ep_lst = []
    for design in design_lst:
        print("Current design: ", design)
        with open(f"{label_dir}/{design}.json", 'r') as f:
            reg_lst = json.load(f)

        for ep in reg_lst:
            emb_path = f"{emb_root}/{design}/{ep}.pkl"
            if not os.path.exists(emb_path):
                # Not every listed ep has an embedding on disk; skip those.
                continue
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point emb_root at trusted data.
            with open(emb_path, 'rb') as f:
                emb = pickle.load(f)
            emb_lst.append(emb)
            design_ep_lst.append((design, ep))

    # Stack into one array so callers can score against all rows at once.
    emb_lst = np.array(emb_lst)
    return emb_lst, design_ep_lst


def zero_shot_retrieve(
    test_lst,
    emb_lst,
    train_ep_lst,
    top_k=5,
    label_dir="/home/coguest5/CircuitFusion/data_collect/label/ep_lst",
    emb_root="../rtl_emb",
    pause=True,
):
    """Print and return the best-scoring training entries for each test ep.

    For each design in ``test_lst``, ``{label_dir}/{design}.json`` gives the
    ``ep`` names; each ``{emb_root}/{design}/{ep}.pkl`` is unpickled and
    scored against every row of ``emb_lst`` with a dot product. The
    ``top_k`` highest-scoring rows are mapped back through ``train_ep_lst``
    and printed.

    Args:
        test_lst: Iterable of design names to query.
        emb_lst: Array of reference embeddings, one per row, as returned by
            ``get_train_emb``.
        train_ep_lst: ``(design, ep)`` entries parallel to ``emb_lst`` rows.
        top_k: Number of best matches to report per query (default 5).
        label_dir: Directory holding one ``{design}.json`` list per design.
        emb_root: Root directory holding ``{design}/{ep}.pkl`` files.
        pause: When true (default), wait for Enter via ``input()`` after each
            design, as the original interactive script did. Pass ``False``
            for non-interactive use.

    Returns:
        Dict mapping each processed ``(design, ep)`` query to its list of
        matched ``(design, ep)`` tuples, best match first. Queries whose
        embedding file is missing are absent from the dict.

    Raises:
        FileNotFoundError: If a design's JSON list is missing.
    """
    results = {}
    for design in test_lst:
        with open(f"{label_dir}/{design}.json", 'r') as f:
            reg_lst = json.load(f)

        for ep in reg_lst:
            print("Current design: ", design, ep)
            emb_path = f"{emb_root}/{design}/{ep}.pkl"
            if not os.path.exists(emb_path):
                # Mirror get_train_emb: a listed ep without an embedding on
                # disk is skipped rather than aborting the whole run.
                print("Not exist: ", design, ep)
                continue
            # NOTE(security): pickle.load executes arbitrary code from the
            # file; only point emb_root at trusted data.
            with open(emb_path, 'rb') as f:
                test_emb = pickle.load(f)

            # Dot-product score against every reference row. The *100 scaling
            # does not affect the ranking. Flatten so that a query stored as
            # shape (1, d) ranks the same way as one stored as shape (d,).
            scores = np.ravel((test_emb @ emb_lst.T) * 100)
            top_k_idx = np.argsort(scores)[::-1][:top_k]

            matches = []
            for idx in top_k_idx:
                design_r = train_ep_lst[idx][0]
                ep_r = train_ep_lst[idx][1]
                print(design_r, ep_r)
                matches.append((design_r, ep_r))
            print('\n')
            results[(design, ep)] = matches

        if pause:
            # Interactive checkpoint: let the user inspect one design's
            # output before moving on to the next.
            input()

    return results
    

if __name__ == '__main__':
    # Read the design-name lists: train designs form the retrieval corpus,
    # test designs are the queries. No `global` declaration is needed here;
    # names bound at script level are already module globals.
    with open("/home/coguest5/CircuitFusion/dataset_js/train_lst.json", 'r') as f:
        train_lst = json.load(f)
    with open("/home/coguest5/CircuitFusion/dataset_js/test_lst.json", 'r') as f:
        test_lst = json.load(f)

    # Load every training embedding found on disk, then rank those
    # embeddings against each test embedding and print the best matches.
    emb_lst, design_ep_lst = get_train_emb(train_lst)
    zero_shot_retrieve(test_lst, emb_lst, design_ep_lst)