import argparse
import os, pickle, json
import numpy as np
import ruamel_yaml as yaml
from accelerate import Accelerator
import torch
from transformers.optimization import (
    AdamW,
    get_polynomial_decay_schedule_with_warmup,
)

from dataset_proc.load_dataset import load_train_valid_dataset, load_test_dataset,MyDataset, TextEmbedDataset

from models.model_pretrain import RTL_Fusion
from utils.eval import regression_metrics
from xgboost import XGBRegressor
## import mlp model
from sklearn.neural_network import MLPRegressor

from accelerate import DistributedDataParallelKwargs

from torch.utils.tensorboard import SummaryWriter  


# Run tag: used below to build the pretrained-model, fine-tuned-model and embedding directories.
date = 'pretrain_align_sync_0817'
# Restrict the process to GPU "1". This is done before the Accelerator is constructed.
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

# find_unused_parameters=True is forwarded to DDP so parameters that get no gradient do not raise.
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

def get_design_metric(design_lst):
    """Collect per-endpoint PPA labels for every design in ``design_lst``.

    For each design, endpoint names are read from
    ``/home/coguest5/ckt_enc2/data_collect/label/ep_lst/<design>.json``.
    An endpoint is kept only when both its ``rtl_func_ori`` and
    ``rtl_func_pos`` text files exist. Its labels are read from
    ``data_collect/label/ppa/cone_pwr_area/<design>/<ep>.json``.

    Returns:
        dict mapping each design name to a list of
        ``[design, ep, slack, pwr, area]`` rows.
    """
    # NOTE(review): this reads the ckt_enc2 tree while the rest of the script
    # reads /home/coguest5/CircuitFusion -- confirm this is intentional.
    root = "/home/coguest5/ckt_enc2"

    def _read_json(path):
        with open(path, 'r') as fh:
            return json.load(fh)

    def _has_func_text(design, ep):
        # Both text variants must exist; ori is checked before pos.
        return all(
            os.path.exists(f"{root}/text_enc/llm_extra/{variant}/{design}/{ep}.txt")
            for variant in ("rtl_func_ori", "rtl_func_pos")
        )

    metrics_by_design = {}
    for design in design_lst:
        rows = []
        for ep in _read_json(f"{root}/data_collect/label/ep_lst/{design}.json"):
            if not _has_func_text(design, ep):
                continue
            ppa = _read_json(f"{root}/data_collect/label/ppa/cone_pwr_area/{design}/{ep}.json")
            rows.append([design, ep, ppa['slack'], ppa['pwr'], ppa['area']])
        metrics_by_design[design] = rows
    return metrics_by_design

def _collect_design_ep_metrics(design):
    """Return ``[design, ep, slack, pwr, area]`` rows for the usable endpoints of ``design``.

    Endpoint names come from ``data_collect/label/ep_lst/<design>.json`` under
    ``/home/coguest5/CircuitFusion``. An endpoint is skipped unless both its
    ``rtl_func_ori`` and ``rtl_func_pos`` text files exist. Labels are read
    from ``data_collect/label/ppa/cone_pwr_area/<design>/<ep>.json``.
    """
    root = "/home/coguest5/CircuitFusion"
    with open(f"{root}/data_collect/label/ep_lst/{design}.json", 'r') as f:
        reg_lst = json.load(f)

    rows = []
    for ep in reg_lst:
        if not os.path.exists(f"{root}/text_enc/llm_extra/rtl_func_ori/{design}/{ep}.txt"):
            continue
        if not os.path.exists(f"{root}/text_enc/llm_extra/rtl_func_pos/{design}/{ep}.txt"):
            continue
        with open(f"{root}/data_collect/label/ppa/cone_pwr_area/{design}/{ep}.json", 'r') as f:
            cone_ppa_dct = json.load(f)
        rows.append([design, ep, cone_ppa_dct['slack'], cone_ppa_dct['pwr'], cone_ppa_dct['area']])
    return rows


def calculate_training_pool(rtl_fusion):
    """Compute and cache the retrieval pool used for zero-shot inference.

    Writes three files into the global ``embed_save_dir``:

    * ``testing_design_metric_dict.json``: per-endpoint labels of the designs
      in ``test_lst.json``.
    * ``training_pool_metric_all.json``: per-endpoint
      ``[design, ep, slack, pwr, area]`` rows of the designs in
      ``train_lst.json``.
    * ``training_pool_embed_all.pkl``: ``rtl_fusion`` embeddings of the
      training loader, as an array of shape ``(N, embed_dim)``.

    Args:
        rtl_fusion: model called as ``rtl_fusion((graph, summary, text), mode='infer')``.

    Returns:
        ``(training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict)``.
    """
    #### testing-design endpoint metrics ####
    with open("/home/coguest5/CircuitFusion/dataset/dataset_js/test_lst.json", 'r') as f:
        test_lst = json.load(f)
    testing_design_metric_dict = {design: _collect_design_ep_metrics(design) for design in test_lst}
    with open(f"{embed_save_dir}/testing_design_metric_dict.json", 'w') as f:
        json.dump(testing_design_metric_dict, f, indent=4)

    #### training-pool endpoint metrics ####
    # Retrieval indexes these rows with argsort indices over the embedding rows
    # computed below, so the two must be aligned.
    # NOTE(review): this assumes load_train_valid_dataset yields endpoints in
    # exactly this order, with the same filtering and no shuffling. TODO confirm.
    with open("/home/coguest5/CircuitFusion/dataset/dataset_js/train_lst.json", 'r') as f:
        train_lst = json.load(f)
    training_pool_metric_all = []
    for design in train_lst:
        for ep_lst in _collect_design_ep_metrics(design):
            training_pool_metric_all.append(ep_lst)
            print(ep_lst)

    print(len(training_pool_metric_all))
    with open(f"{embed_save_dir}/training_pool_metric_all.json", 'w') as f:
        json.dump(training_pool_metric_all, f)

    #### training-pool embeddings ####
    accelerator.print("Loading Dataset Training Pool ...")
    train_rtl_loader = load_train_valid_dataset(batch_size=32, train_valid="train")
    accelerator.print("Dataset Loaded!")
    (rtl_fusion, train_rtl_loader) = accelerator.prepare(rtl_fusion, train_rtl_loader)
    graph_ori_loader_train, _, _, summary_loader_train, text_loader_train = train_rtl_loader

    # NOTE(review): the model is never switched to eval(). If RTL_Fusion uses
    # dropout/batch-norm and mode='infer' does not handle that, these
    # embeddings are stochastic. Confirm.
    batch_embeds = []
    with torch.no_grad():  # inference only: skip building autograd graphs
        for graph_data, batched_summary, text_data in zip(
                graph_ori_loader_train, summary_loader_train, text_loader_train):
            summary_data = batched_summary[0]
            fusion_emb = rtl_fusion((graph_data, summary_data, text_data), mode='infer')
            batch_embeds.append(fusion_emb.cpu().detach().numpy())
    # Concatenate once instead of np.append per batch, which was quadratic.
    # The float64 (0, embed_dim) seed keeps the previous dtype, the column
    # check and the empty-loader result.
    training_pool_embed_all = np.concatenate([np.empty((0, embed_dim))] + batch_embeds, axis=0)

    print(training_pool_embed_all.shape)

    with open(f"{embed_save_dir}/training_pool_embed_all.pkl", 'wb') as f:
        pickle.dump(training_pool_embed_all, f)

    return training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict





def zero_shot_infer_one_design(design, rtl_fusion,
                               training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict):
    """Build the design-level feature vector and label for ``design``.

    Each endpoint of ``design`` (from ``load_test_dataset``) is embedded with
    ``rtl_fusion``. Its ``top_k`` highest-scoring training-pool rows (scaled
    dot product) are retrieved, and the target selected by the global ``cmd``
    is accumulated as:

    * ``val_r1``: sum over endpoints of the best neighbour's value.
    * ``val_r5``: sum over endpoints of the mean value of the ``top_k`` neighbours.

    If ``design`` is in the global ``train_lst``, the best match is dropped,
    presumably because it is the endpoint itself.

    ``testing_design_metric_dict`` is accepted for interface compatibility
    but is not used.

    Returns:
        ``(feat, real)``:

        * ``feat`` is the list loaded from
          ``rtl_graph/feat/ori/<design>_feat.json``, extended with the mean
          endpoint embedding and ``[val_r1, val_r5]``.
        * ``real`` is the ``cmd`` entry of ``ppa/json/<design>/ppa.json``.

    Raises:
        ValueError: if the global ``cmd`` is not 'tns', 'pwr' or 'area'.
    """
    # Column of each target in a pool row [design, ep, slack, pwr, area].
    # 'tns' retrieves the endpoint slack.
    metric_col = {'tns': 2, 'pwr': 3, 'area': 4}.get(cmd)
    if metric_col is None:
        # Previously this surfaced as an UnboundLocalError after all the embedding work.
        raise ValueError(f"unsupported cmd {cmd!r}; expected 'tns', 'pwr' or 'area'")

    print('Current Design: ', design)
    graph_loader, summary_loader, text_loader = load_test_dataset(design, batch_size=16)

    batch_embeds = []
    with torch.no_grad():  # inference only: skip building autograd graphs
        for graph_data, batched_summary, text_data in zip(graph_loader, summary_loader, text_loader):
            summary_data = batched_summary[0]
            fusion_emb = rtl_fusion((graph_data, summary_data, text_data), mode='infer')
            batch_embeds.append(fusion_emb.cpu().detach().numpy())
    # Single concatenation. The float64 seed keeps the previous dtype and empty-case shape.
    design_embed_all = np.concatenate([np.empty((0, embed_dim))] + batch_embeds, axis=0)

    top_k = 5
    skip_self = design in train_lst  # loop-invariant, so evaluated once
    val_r1, val_r5 = 0, 0
    for test_ep_embed in design_embed_all:
        scores = (test_ep_embed @ training_pool_embed_all.T) * 100
        ranked = np.argsort(scores)[::-1]
        top_k_idx = ranked[1:top_k + 1] if skip_self else ranked[:top_k]
        neighbour_vals = [training_pool_metric_all[i][metric_col] for i in top_k_idx]
        val_r1 += neighbour_vals[0]
        val_r5 += np.mean(neighbour_vals)

    avg_embed = np.mean(design_embed_all, axis=0)
    print(avg_embed.shape)

    ### design-level graph feature
    with open(f"/home/coguest5/CircuitFusion/data_collect/dataset/rtl_graph/feat/ori/{design}_feat.json", 'r') as f:
        feat = json.load(f)
    feat.extend(list(avg_embed))
    feat.extend([val_r1, val_r5])

    ### design-level PPA label; the JSON keys match the cmd names
    with open(f"/home/coguest5/CircuitFusion/data_collect/label/ppa/json/{design}/ppa.json", 'r') as f:
        cone_ppa_dct = json.load(f)
    real = cone_ppa_dct[cmd]

    return feat, real


def run_all_design(design_lst, rtl_fusion,
                   training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict):
    """Fit an XGBoost regressor for the current ``cmd`` target and pickle it.

    One feature vector and label per design are produced by
    ``zero_shot_infer_one_design``. The fitted model is written to
    ``{ft_model_save_dir}/{cmd}_xgb_model.pkl``.

    Raises:
        ValueError: if ``design_lst`` is empty, since there is nothing to fit.
    """
    if not design_lst:
        raise ValueError("run_all_design: design_lst is empty; nothing to fit")

    feat_lst, real_lst = [], []
    for design in design_lst:
        feat, real = zero_shot_infer_one_design(design, rtl_fusion,
                                                training_pool_metric_all, training_pool_embed_all,
                                                testing_design_metric_dict)
        feat_lst.append(feat)
        real_lst.append(real)

    feat_all = np.array(feat_lst)
    real_all = np.array(real_lst)
    print(feat_all.shape, real_all.shape)

    model = XGBRegressor(n_estimators=100, max_depth=30)
    model.fit(feat_all, real_all)

    # Nothing else in this script creates ft_model_save_dir. Without this,
    # the dump raises FileNotFoundError on a fresh run.
    os.makedirs(ft_model_save_dir, exist_ok=True)
    with open(f"{ft_model_save_dir}/{cmd}_xgb_model.pkl", 'wb') as f:
        pickle.dump(model, f)
    

def test(design_lst, rtl_fusion,
         training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict):
    """Evaluate the saved regressor for the current ``cmd`` target on ``design_lst``.

    Loads ``{ft_model_save_dir}/{cmd}_xgb_model.pkl`` and builds one feature
    vector per design via ``zero_shot_infer_one_design``. It then prints
    ``regression_metrics`` of the predictions against the design-level labels.
    """
    model_path = f"{ft_model_save_dir}/{cmd}_xgb_model.pkl"
    with open(model_path, 'rb') as fh:
        # pickle.load can execute arbitrary code; only load model files this script produced.
        model = pickle.load(fh)

    pairs = [
        zero_shot_infer_one_design(design, rtl_fusion,
                                   training_pool_metric_all, training_pool_embed_all,
                                   testing_design_metric_dict)
        for design in design_lst
    ]
    features = [feat for feat, _ in pairs]
    labels = [real for _, real in pairs]

    predictions = model.predict(np.array(features))

    print(f'\n\n=== Few-shot Inference Results ({cmd}):===')
    regression_metrics(predictions, labels)

    


        





    
                    


if __name__ == '__main__':

    # The module-level names assigned here (cmd, embed_dim, embed_save_dir,
    # ft_model_save_dir, train_lst) are read as globals by the functions above.
    cmd = 'area'  # regression target: 'tns', 'pwr' or 'area'
    embed_dim = 768

    epoch = 20
    model_save_dir = f"./pretrain_model/{date}"
    ft_model_save_dir = f"./finetune_model/{date}"
    embed_save_dir = f'./embeds/{date}_{epoch}'
    # makedirs also creates missing parents (./embeds, ./finetune_model); os.mkdir did not.
    os.makedirs(embed_save_dir, exist_ok=True)
    os.makedirs(ft_model_save_dir, exist_ok=True)

    ### load pretrained model ###
    # NOTE(review): this unpickles a whole module object. Newer torch releases
    # default torch.load to weights_only=True, which would reject it. Check the
    # installed torch version if loading fails.
    rtl_fusion = torch.load(f"{model_save_dir}/rtl_fusion.{epoch}.pt")

    ### load (or compute and cache) the training pool ###
    pool_embed_path = f"{embed_save_dir}/training_pool_embed_all.pkl"
    if not os.path.exists(pool_embed_path):
        training_pool_metric_all, training_pool_embed_all, _ = calculate_training_pool(rtl_fusion)
    else:
        with open(f"{embed_save_dir}/training_pool_metric_all.json", 'r') as f:
            training_pool_metric_all = json.load(f)
        with open(pool_embed_path, 'rb') as f:
            training_pool_embed_all = pickle.load(f)

    ### design lists ###
    with open("/home/coguest5/CircuitFusion/dataset/dataset_js/train_lst.json", 'r') as f:
        train_lst = json.load(f)
    # Other fine-tuning subsets in the same directory: sft_ft_lst_4.json,
    # sft_ft_lst.json, sft_ft_lst_16.json.
    with open("/home/coguest5/CircuitFusion/dataset/dataset_js/sft_ft_lst_all.json", 'r') as f:
        sft_lst = json.load(f)
    sft_design_metric_dict = get_design_metric(sft_lst)
    with open("/home/coguest5/CircuitFusion/dataset/dataset_js/test_lst.json", 'r') as f:
        test_lst = json.load(f)
    testing_design_metric_dict = get_design_metric(test_lst)

    ### fit on the fine-tuning designs, then evaluate on the test designs ###
    run_all_design(sft_lst, rtl_fusion,
                   training_pool_metric_all, training_pool_embed_all, sft_design_metric_dict)
    test(test_lst, rtl_fusion,
         training_pool_metric_all, training_pool_embed_all, testing_design_metric_dict)