import argparse
import os, pickle, json
import numpy as np
import ruamel_yaml as yaml
from accelerate import Accelerator
import torch
from transformers.optimization import (
    AdamW,
    get_polynomial_decay_schedule_with_warmup,
)

from dataset_proc.load_dataset import load_train_valid_dataset, load_test_dataset,MyDataset, TextEmbedDataset

from models.model_pretrain import RTL_Fusion
from utils.eval import regression_metrics

from accelerate import DistributedDataParallelKwargs

from torch.utils.tensorboard import SummaryWriter  


# Run tag: selects the ./pretrain_model/{date} checkpoint directory and names the
# ./embeds/{date}_{epoch} directory created in the __main__ block.
date ='pretrain_align_sync_0817'
# Restrict this process to physical GPU 2. CUDA_VISIBLE_DEVICES only takes effect if
# set before CUDA is initialised, hence it sits above the Accelerator() construction.
# NOTE(review): this unconditionally overrides any value already in the environment.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# find_unused_parameters=True lets DDP tolerate parameters that receive no gradient.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# NOTE(review): `accelerator` is not referenced anywhere else in this script
# (no prepare()/device placement); the model is used as returned by torch.load.
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

def zero_shot_infer_one_design(design, rtl_fusion, save_dir=".visualization/embeds"):
    """Compute fused embeddings for every test sample of one design and pickle them.

    Iterates the graph, summary and text test loaders of ``design`` in lockstep,
    runs ``rtl_fusion`` in ``mode='infer'`` on each batch, stacks the per-batch
    outputs along axis 0 and writes the resulting NumPy array to
    ``{save_dir}/{design}.pkl``.

    Args:
        design: Design name passed to ``load_test_dataset``; also the output file stem.
        rtl_fusion: Model called as ``rtl_fusion((graph, summary, text), mode='infer')``
            returning a tensor; presumably shaped (batch, embed_dim) — TODO confirm.
        save_dir: Output directory, created if missing. Defaults to the location
            this function has always written to.

    Reads the module-level ``embed_dim`` (assigned in the ``__main__`` block) to
    seed the result, so a design with no batches yields a (0, embed_dim) array.
    """
    print('Current Design: ', design)
    graph_loader, summary_loader, text_loader = load_test_dataset(design, batch_size=16)

    # Seed with an empty float64 (0, embed_dim) array so the output dtype and the
    # no-batch case match the previous np.append-based accumulation.
    chunks = [np.empty((0, embed_dim))]
    # Pure inference: skip building autograd graphs to save memory and time.
    with torch.no_grad():
        for graph_data, batched_summary, text_data in zip(
            graph_loader, summary_loader, text_loader
        ):
            summary_data = batched_summary[0]
            fusion_emb = rtl_fusion((graph_data, summary_data, text_data), mode='infer')
            chunks.append(fusion_emb.detach().cpu().numpy())
    # A single concatenate instead of np.append per batch, which re-copied the
    # whole accumulated array every iteration.
    design_embed_all = np.concatenate(chunks, axis=0)

    # The hard-coded output directory was never created, making open() fail on a
    # fresh checkout.
    os.makedirs(save_dir, exist_ok=True)
    with open(os.path.join(save_dir, f"{design}.pkl"), 'wb') as f:
        pickle.dump(design_embed_all, f)

def run_all_design(design_lst, rtl_fusion):
    """Run zero-shot embedding inference on each design of ``design_lst`` in order.

    Every design name is handed to ``zero_shot_infer_one_design`` together with
    the shared ``rtl_fusion`` model. Nothing is returned; results are written by
    the per-design call.
    """
    for design_name in design_lst:
        zero_shot_infer_one_design(design_name, rtl_fusion)



if __name__ == '__main__':
    # Command-line overrides; the defaults reproduce the previously hard-coded values.
    parser = argparse.ArgumentParser(
        description="Zero-shot embedding inference with a pretrained RTL_Fusion checkpoint."
    )
    parser.add_argument(
        "--epoch", type=int, default=20,
        help="Checkpoint epoch to load (rtl_fusion.{epoch}.pt).",
    )
    parser.add_argument(
        "--test-list",
        default="/home/coguest5/CircuitFusion/dataset/dataset_js/test_lst.json",
        help="JSON file holding the list of test design names.",
    )
    args = parser.parse_args()

    # Assigned at module scope, so these are globals; zero_shot_infer_one_design
    # reads `embed_dim` from here.
    embed_dim = 768
    epoch = args.epoch

    print(f"Epoch: {epoch}")

    model_save_dir = f"./pretrain_model/{date}"
    embed_save_dir = f'./embeds/{date}_{epoch}'
    # makedirs also creates a missing ./embeds parent (os.mkdir raised), and
    # exist_ok removes the exists-then-create race.
    # NOTE(review): embed_save_dir is created but never passed on; the per-design
    # pickles are written elsewhere — confirm which directory is intended.
    os.makedirs(embed_save_dir, exist_ok=True)

    # Loads a fully pickled model object. NOTE(review): torch>=2.6 defaults to
    # weights_only=True, which rejects full-module pickles; on such versions pass
    # weights_only=False (only for trusted checkpoints).
    rtl_fusion = torch.load(f"{model_save_dir}/rtl_fusion.{epoch}.pt")
    # Switch any dropout / batch-norm layers to inference behaviour so the
    # embeddings are deterministic (assumes the checkpoint is an nn.Module).
    rtl_fusion.eval()

    with open(args.test_list, 'r', encoding='utf-8') as f:
        test_lst = json.load(f)

    run_all_design(test_lst, rtl_fusion)
