import argparse
import os, pickle, json
import numpy as np
import ruamel_yaml as yaml
from accelerate import Accelerator
import torch
from transformers.optimization import (
    AdamW,
    get_polynomial_decay_schedule_with_warmup,
)

from dataset_proc.load_dataset import load_train_valid_dataset, load_test_dataset,MyDataset, TextEmbedDataset

from models.model_pretrain import RTL_Fusion
from utils.eval import regression_metrics

from accelerate import DistributedDataParallelKwargs

from torch.utils.tensorboard import SummaryWriter  


# Experiment tag. It is interpolated into the checkpoint directory
# (./pretrain_model/<date>), the embedding cache directory
# (./embeds/<date>_<epoch>) and the report directory (./rpt/<date>) below.
date ='pretrain_align_sync_0817'

# Pin the process to one GPU. This runs before the Accelerator is constructed
# on the lines below; NOTE(review): the device index "2" is machine-specific.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Module-level Accelerator shared by the functions below (used for printing
# and for accelerator.prepare). find_unused_parameters=True is passed through
# the DDP kwargs handler.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

def _collect_ep_metrics(design):
    """Return ``[design, ep, slack, pwr, area]`` rows for one design.

    Endpoints are read from the design's ``ep_lst`` JSON. An endpoint is
    skipped unless both its ``rtl_func_ori`` and ``rtl_func_pos`` text files
    exist; for the remaining endpoints the labels are read from the
    ``cone_pwr_area`` JSON (keys ``slack``, ``pwr``, ``area``).
    """
    with open (f"/home/coguest5/CircuitFusion/data_collect/label/ep_lst/{design}.json", 'r') as f:
        reg_lst = json.load(f)

    rows = []
    for ep in reg_lst:
        if not os.path.exists(f"/home/coguest5/CircuitFusion/text_enc/llm_extra/rtl_func_ori/{design}/{ep}.txt"):
            continue
        if not os.path.exists(f"/home/coguest5/CircuitFusion/text_enc/llm_extra/rtl_func_pos/{design}/{ep}.txt"):
            continue

        ### PPA label ###
        with open (f"/home/coguest5/CircuitFusion/data_collect/label/ppa/cone_pwr_area/{design}/{ep}.json", 'r') as f:
            cone_ppa_dct = json.load(f)
        rows.append([design, ep, cone_ppa_dct['slack'], cone_ppa_dct['pwr'], cone_ppa_dct['area']])
    return rows


def calculate_training_pool(rtl_fusion):
    """Build and cache the label tables and the training-pool embeddings.

    Relies on the module-level globals ``embed_save_dir`` (output directory),
    ``embed_dim`` (embedding width) and ``accelerator``.

    Side effects: writes ``testing_design_metric_dict.json``,
    ``training_pool_metric_all.json`` and ``training_pool_embed_all.pkl``
    into ``embed_save_dir``.

    Args:
        rtl_fusion: model called as ``rtl_fusion((graph, summary, text),
            mode='infer')``; its output is converted to a NumPy array.

    Returns:
        ``(training_pool_metric_all, training_pool_embed_all,
        testing_design_metric_dict)`` where the first is a flat list of
        ``[design, ep, slack, pwr, area]`` rows over all training designs,
        the second is an ``(N, embed_dim)`` array and the third maps each
        test design to its list of rows.
    """
    #### get testing design metric ####
    with open("/home/coguest5/CircuitFusion/dataset/dataset_js/test_lst.json", 'r') as f:
        test_lst = json.load(f)
    testing_design_metric_dict = {design: _collect_ep_metrics(design) for design in test_lst}
    with open (f"{embed_save_dir}/testing_design_metric_dict.json", 'w') as f:
        json.dump(testing_design_metric_dict, f, indent=4)

    #### get training pool metric ####
    with open("/home/coguest5/CircuitFusion//dataset/dataset_js/train_lst.json", 'r') as f:
        train_lst = json.load(f)
    training_pool_metric_all = []
    for design in train_lst:
        training_pool_metric_all.extend(_collect_ep_metrics(design))
    with open (f"{embed_save_dir}/training_pool_metric_all.json", 'w') as f:
        json.dump(training_pool_metric_all, f)

    #### get training pool embed ####
    accelerator.print("Loading Dataset Training Pool ...")
    train_rtl_loader = load_train_valid_dataset(batch_size=16, train_valid="train")
    accelerator.print("Dataset Loaded!")
    (rtl_fusion, train_rtl_loader) = accelerator.prepare(rtl_fusion, train_rtl_loader)
    # The loader bundle has five entries; the second and third are not used here.
    graph_ori_lader_train, _, _, summary_loader_train, text_loader_train = train_rtl_loader

    # Collect per-batch arrays and concatenate once: np.append inside the loop
    # re-copies everything accumulated so far on every batch. The empty
    # (0, embed_dim) seed keeps the result dtype and the width check identical
    # to the incremental version.
    embed_chunks = [np.empty((0, embed_dim))]
    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        for graph_data, batched_summary, text_data in zip(
                graph_ori_lader_train, summary_loader_train, text_loader_train):
            # Only the first element of each summary batch is fed to the model.
            summary_data = batched_summary[0]
            fusion_emb = rtl_fusion((graph_data, summary_data, text_data), mode='infer')
            embed_chunks.append(fusion_emb.cpu().detach().numpy())
    training_pool_embed_all = np.concatenate(embed_chunks, axis=0)

    print(training_pool_embed_all.shape)

    # NOTE(review): downstream code indexes training_pool_metric_all by the row
    # index of this array, so it assumes the loaders yield endpoints in the
    # same order (and with the same skips) as the metric list - confirm
    # against load_train_valid_dataset.
    with open (f"{embed_save_dir}/training_pool_embed_all.pkl", 'wb') as f:
        pickle.dump(training_pool_embed_all, f)

    return training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict





def zero_shot_infer_one_design(design, rtl_fusion,
                               training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict):
    """Predict slack/power/area for one test design by nearest-neighbour lookup.

    Every endpoint of ``design`` is embedded with ``rtl_fusion``, scored
    against all training-pool embeddings with a dot product, and its labels
    are predicted from the labels of the best-scoring pool rows (top-1, and
    the mean of the top 3, 5 and 10). Predictions are then passed to
    ``regression_metrics`` together with the true labels.

    Relies on the module-level global ``embed_dim``.

    Args:
        design: key into ``testing_design_metric_dict`` and argument to
            ``load_test_dataset``.
        rtl_fusion: model called with ``mode='infer'``.
        training_pool_metric_all: rows ``[design, ep, slack, pwr, area]``,
            row ``i`` describing row ``i`` of ``training_pool_embed_all``.
        training_pool_embed_all: 2-D array of training-pool embeddings.
        testing_design_metric_dict: maps a design to its label rows, in the
            same layout as ``training_pool_metric_all``.

    Returns:
        A 12-tuple of ``regression_metrics`` results ordered
        slack R@1/3/5/10, power R@1/3/5/10, area R@1/3/5/10.

    Raises:
        ValueError: if the design has endpoints but the training pool is empty.
    """
    print('Current Design: ', design)
    graph_loader, summary_loader, text_loader = load_test_dataset(design, batch_size=16)

    # Concatenate once instead of np.append per batch (which re-copies the
    # accumulated array each time); the empty seed preserves dtype and width
    # checking.
    embed_chunks = [np.empty((0, embed_dim))]
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for graph_data, batched_summary, text_data in zip(graph_loader, summary_loader, text_loader):
            # Only the first element of each summary batch is fed to the model.
            summary_data = batched_summary[0]
            fusion_emb = rtl_fusion((graph_data, summary_data, text_data), mode='infer')
            embed_chunks.append(fusion_emb.cpu().detach().numpy())
    design_embed_all = np.concatenate(embed_chunks, axis=0)

    ranks = (1, 3, 5, 10)
    # (metric name, column index inside a label row, console header)
    targets = (
        ('slack', 2, '\n---- Slack Eval ----'),
        ('pwr', 3, '\n---- Power Eval ----'),
        ('area', 4, '\n---- Area Eval ----'),
    )
    real = {name: [] for name, _, _ in targets}
    pred = {(name, k): [] for name, _, _ in targets for k in ranks}

    # NOTE(review): label row idx is paired with embedding row idx, so this
    # assumes load_test_dataset yields endpoints in the order stored in
    # testing_design_metric_dict - confirm.
    design_labels = testing_design_metric_dict[design]
    max_k = max(ranks)

    for idx, test_ep_embed in enumerate(design_embed_all):
        ### get label ###
        label_row = design_labels[idx]
        for name, col, _ in targets:
            real[name].append(label_row[col])

        ### calculate distance scores ###
        # Dot-product similarity; highest score first.
        scores = (test_ep_embed @ training_pool_embed_all.T) * 100
        top_k_idx = np.argsort(scores)[::-1][:max_k]
        if top_k_idx.size == 0:
            raise ValueError("training pool is empty; cannot retrieve neighbours")

        for name, col, _ in targets:
            neighbours = [training_pool_metric_all[i][col] for i in top_k_idx]
            for k in ranks:
                if k == 1:
                    # Top-1 uses the neighbour's label as-is (no averaging).
                    pred[(name, k)].append(neighbours[0])
                else:
                    # Slicing averages the neighbours that exist when the
                    # pool holds fewer than k rows, instead of raising.
                    pred[(name, k)].append(np.mean(neighbours[:k]))

    results = []
    for name, _, header in targets:
        print(header)
        for k in ranks:
            print(f'R@{k}')
            results.append(regression_metrics(pred[(name, k)], real[name]))
    return tuple(results)


def run_all_design(design_lst, rtl_fusion,
                   training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict):
    """Evaluate every design and write the averaged report.

    Calls ``zero_shot_infer_one_design`` for each entry of ``design_lst``,
    averages element ``[0]`` (reported as ``R``) and element ``[1]``
    (reported as ``MAPE``) of each of the 12 per-design metric results across
    designs, writes them to ``./rpt/<date>/zero_shot_<epoch>.new.txt`` and
    echoes the file to stdout.

    Relies on the module-level globals ``date`` and ``epoch``.

    Args:
        design_lst: iterable of design names to evaluate.
        rtl_fusion, training_pool_metric_all, training_pool_embed_all,
        testing_design_metric_dict: forwarded unchanged to
            ``zero_shot_infer_one_design``.

    Returns:
        None.
    """
    metric_lst = [
        zero_shot_infer_one_design(design, rtl_fusion,
                                   training_pool_metric_all, training_pool_embed_all,
                                   testing_design_metric_dict)
        for design in design_lst
    ]

    ranks = (1, 3, 5, 10)
    # (section title, index of that section's R@1 entry in the 12-tuple
    # returned by zero_shot_infer_one_design)
    sections = (('Slack', 0), ('Power', 4), ('Area', 8))

    report = ['==== Design Summary ====\n']
    for title, first_slot in sections:
        report.append(f'\n---- {title} Eval ----\n')
        for offset, k in enumerate(ranks):
            slot = first_slot + offset
            mean_r = np.mean([metric[slot][0] for metric in metric_lst])
            mean_mape = np.mean([metric[slot][1] for metric in metric_lst])
            report.append(f'R@{k}\n')
            report.append('R: '+str(mean_r)+'\n')
            report.append('MAPE: '+str(mean_mape)+'\n')

    rpt_save_dir = f"./rpt/{date}"
    # makedirs also creates a missing ./rpt parent and tolerates an existing
    # directory, unlike the exists()+mkdir() pair.
    os.makedirs(rpt_save_dir, exist_ok=True)
    rpt_path = f"{rpt_save_dir}/zero_shot_{epoch}.new.txt"
    with open (rpt_path, 'w') as f:
        f.writelines(report)

    # Echo the report as written on disk. Each line keeps its trailing
    # newline and print() adds another, so the console copy is double-spaced.
    with open (rpt_path, 'r') as f:
        for line in f:
            print(line)

    


    
                    


if __name__ == '__main__':
    # When True the label tables and pool embeddings are always recomputed,
    # even if cached copies exist in embed_save_dir.
    clear = True

    # embed_dim, epoch and embed_save_dir are module-level names read by the
    # functions above, so they must be assigned here before those are called.
    embed_dim = 768

    # Previously this was ``for epoch in range(23, 50, 3)`` with ``epoch = 20``
    # assigned right after the progress print and ``exit()`` at the end of the
    # body: the log said "Epoch: 23" while checkpoint 20 was loaded, and only
    # one iteration ever ran. The effective behaviour (a single evaluation of
    # checkpoint 20) is kept, with a message that names the right checkpoint.
    epoch = 20
    print(f"Epoch: {epoch}")

    model_save_dir = f"./pretrain_model/{date}"
    embed_save_dir = f'./embeds/{date}_{epoch}'
    # Also creates a missing ./embeds parent; no error if it already exists.
    os.makedirs(embed_save_dir, exist_ok=True)

    # NOTE(security): torch.load unpickles the checkpoint; only load files
    # from a trusted source.
    rtl_fusion = torch.load(f"{model_save_dir}/rtl_fusion.{epoch}.pt")

    metric_cache = f"{embed_save_dir}/training_pool_metric_all.json"
    embed_cache = f"{embed_save_dir}/training_pool_embed_all.pkl"
    test_metric_cache = f"{embed_save_dir}/testing_design_metric_dict.json"
    # All three files are read in the cached branch, so all three must exist.
    cache_complete = all(
        os.path.exists(path) for path in (metric_cache, embed_cache, test_metric_cache)
    )

    if clear or not cache_complete:
        training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict = calculate_training_pool(rtl_fusion)
    else:
        with open (metric_cache, 'r') as f:
            training_pool_metric_all = json.load(f)
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # this cache is expected to be one written by calculate_training_pool.
        with open (embed_cache, 'rb') as f:
            training_pool_embed_all = pickle.load(f)
        with open (test_metric_cache, 'r') as f:
            testing_design_metric_dict = json.load(f)

    with open("/home/coguest5/CircuitFusion/dataset/dataset_js/test_lst.json", 'r') as f:
        test_lst = json.load(f)

    run_all_design(test_lst, rtl_fusion,
                   training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict)
    # Terminate explicitly, as before; SystemExit does not depend on the
    # site-provided exit() helper.
    raise SystemExit
