import os
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from feature_extration.process_data import Paraser, get_cell
from torch.utils.data import DataLoader
from multiprocessing import Process
from timeit import default_timer as timer
from logger import Logger, AverageMeter, time_to_str
from models.cnngnn import create_cnngnn
from models.NetlistGNN import NetlistGNN


def combine_feat(gnnfeat, vitfeat, pos, afa):
    """Cross-blend GNN features with the ViT tokens selected by ``pos``.

    The ViT tokens ``vitfeat[:, pos, :]`` are overwritten in place with
    ``(1 - afa) * vit + afa * gnn``. The GNN features are blended the other way,
    ``(1 - afa) * gnn + afa * vit``, using the ViT tokens as they were *before*
    the in-place update.

    Args:
        gnnfeat: GNN features; must broadcast against ``vitfeat[:, pos, :]``.
        vitfeat: ViT token tensor indexed as ``vitfeat[:, pos, :]``; it is
            modified in place and also returned.
        pos: index selecting tokens along dim 1 (index tensor, int or slice).
        afa: blending weight; ``0`` leaves both sides unchanged.

    Returns:
        tuple: ``(blended_gnn[0], vitfeat, dif)``, where ``dif`` is the summed
        L1 distance between the original selected ViT tokens and ``gnnfeat``.
    """
    # clone(): with an int/slice ``pos``, basic indexing returns a *view*, so the
    # in-place write into ``vitfeat`` below would otherwise also change
    # ``vit_part`` and corrupt the GNN-side blend. An index tensor already
    # yields a copy, so results are unchanged for that case.
    vit_part = vitfeat[:, pos, :].clone()
    dif = torch.nn.L1Loss(reduction='sum')(vit_part, gnnfeat)
    vitfeat[:, pos, :] = (1 - afa) * vit_part + afa * gnnfeat
    gnnfeat = (1 - afa) * gnnfeat + afa * vit_part
    return gnnfeat[0], vitfeat, dif

def get_afa(args, epoch):
    """Return the feature-blending weight for ``epoch`` from a linear schedule.

    The schedule is ``np.linspace(args.feat_start, args.feat_end,
    num=args.epochs - args.start_epoch)``. It takes the value ``feat_start`` at
    ``args.start_epoch`` and ``feat_end`` at ``args.epochs - 1``.

    Args:
        args: namespace providing ``feat_start``, ``feat_end``, ``epochs`` and
            ``start_epoch``.
        epoch: absolute epoch number.

    Raises:
        IndexError: if ``epoch`` is outside ``[args.start_epoch, args.epochs)``.
    """
    num_steps = args.epochs - args.start_epoch
    step = epoch - args.start_epoch
    # Reject negative steps explicitly: numpy would otherwise wrap the index and
    # silently return a value from the end of the schedule.
    if not 0 <= step < num_steps:
        raise IndexError(
            f"epoch {epoch} outside schedule range [{args.start_epoch}, {args.epochs})"
        )
    return np.linspace(args.feat_start, args.feat_end, num=num_steps)[step]

def _export_fused_features(loader, graph_lists, names, modelvit, modelgnn, afa, save_dir, device):
    """Fuse ViT and GNN features for one dataset split and save them as .npy.

    For each design, the ViT feature map is computed once. Then every graph of
    the design is run through the GNN, and its features are blended into the ViT
    tokens at the graph's positions via ``combine_feat``. The fused map and the
    labels are saved as ``save_dir/features/<name>.npy`` and
    ``save_dir/labels/<name>.npy``, both permuted to channel-first
    ``(C, img_h, img_w)``.

    Args:
        loader: DataLoader yielding ``(features, labels, pos_flatten)``; the
            reshapes below assume batch size 1.
        graph_lists: per-design lists of heterographs, aligned with ``loader``.
        names: per-design output file names, aligned with ``loader``.
        modelvit: ViT model exposing ``forward_features``, ``patch_embed.img_size``
            and ``point_embed_dim``.
        modelgnn: GNN model exposing ``forward_feaures``.
        afa: blending weight passed to ``combine_feat``.
        save_dir: root directory containing ``labels`` and ``features`` subdirs.
        device: device the models live on; inputs are moved there.
    """
    img_h, img_w = modelvit.patch_embed.img_size[0], modelvit.patch_embed.img_size[1]
    for (features, labels, pos_flatten), graphs, name in zip(loader, graph_lists, names):
        # NOTE(review): assumes forward_features yields img_h * img_w tokens of
        # size point_embed_dim for the single batch item, as the reshape requires.
        vit_feat = modelvit.forward_features(features.to(device))
        for hetero_graph, pos in zip(graphs, pos_flatten):
            hetero_graph = hetero_graph.to(device)
            # pos[0] holds per-node (pos[..., 0], pos[..., 1]) pairs (batch size 1).
            # Flatten them into indices of the row-major token grid that the
            # reshape below uses. The original hard-coded a width of 256, which
            # only matches when img_w == 256.
            flat_pos = (pos[0, :, 0] * img_w + pos[0, :, 1]).to(vit_feat.device)
            # NOTE(review): "forward_feaures" presumably matches the method name
            # defined on NetlistGNN elsewhere; confirm before renaming.
            gnnfeat = modelgnn.forward_feaures(
                in_node_feat=hetero_graph.nodes['node'].data['hv'],
                in_net_feat=hetero_graph.nodes['net'].data['hv'],
                in_pin_feat=hetero_graph.edges['pinned'].data['he'],
                in_edge_feat=hetero_graph.edges['near'].data['he'],
                node_net_graph=hetero_graph,
            )
            # Only the updated ViT map is exported; the blended GNN features and
            # the L1 difference are not used here.
            _, vit_feat, _ = combine_feat(gnnfeat, vit_feat, flat_pos, afa)
        vit_feat = vit_feat.reshape(img_h, img_w, modelvit.point_embed_dim).permute(2, 0, 1)
        labels = labels.reshape(img_h, img_w, -1).permute(2, 0, 1)
        np.save(os.path.join(save_dir, "labels", name), labels.cpu().numpy())
        np.save(os.path.join(save_dir, "features", name), vit_feat.cpu().numpy())


if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    argp = Paraser()
    arg = argp.parser.parse_args()

    # Create the per-run results directory and log file.
    log = Logger()
    result_dir = os.path.join('results', arg.data_set, cur_time)
    os.makedirs(result_dir, exist_ok=True)
    log.open(os.path.join(result_dir, 'loss.log'), mode="w")

    split_lists = [arg.train_list, arg.test_list]
    dataset_train = get_cell(arg.data_root, arg.graph_save_root, split_lists, arg.lef_path,
                             arg.place_def_root, arg.route_def_root, 1)
    dataset_val = get_cell(arg.data_root, arg.graph_save_root, split_lists, arg.lef_path,
                           arg.place_def_root, arg.route_def_root, 0)
    data_loader_train = DataLoader(dataset_train, batch_size=1, shuffle=False)
    data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False)

    print('##### GNN MODEL #####')
    # Infer input/output feature sizes from the first graph of the first design.
    sample_graph = dataset_train.graphs[0][0]
    in_node_feats = sample_graph.nodes['node'].data['hv'].shape[1]
    in_net_feats = sample_graph.nodes['net'].data['hv'].shape[1]
    in_pin_feats = sample_graph.edges['pinned'].data['he'].shape[1]
    out_node_feats = sample_graph.nodes['node'].data['label'].shape[1]
    config = {
        'N_LAYER': arg.gnn_layers,
        'NODE_FEATS': arg.node_feats,
        'NET_FEATS': arg.net_feats,
        'PIN_FEATS': arg.pin_feats,
        'EDGE_FEATS': arg.edge_feats,
    }
    modelgnn = NetlistGNN(
        in_node_feats=in_node_feats,
        in_net_feats=in_net_feats,
        in_pin_feats=in_pin_feats,
        # NOTE(review): hard-coded to 1 rather than the width of
        # edges['near'].data['he']; kept as-is, presumably to match the checkpoint.
        in_edge_feats=1,
        n_target=out_node_feats,
        activation=arg.outtype,
        config=config,
        recurrent=arg.recurrent,
        topo_conv_type=arg.topo_conv_type,
        geom_conv_type=arg.geom_conv_type,
        agg_type=arg.agg_type,
        cat_raw=arg.cat_raw
    ).to(device)
    modelgnn.load_state_dict(torch.load(arg.gnn_checkpoint_root, map_location='cpu'))

    modelvit = create_cnngnn(log, vit_model_name=arg.vit_model_name, vit_layers=arg.vit_layers,
                             in_channel=dataset_train.map_feature_num,
                             nb_classes=out_node_feats).to(device)
    modelvit.load_state_dict(torch.load(arg.vit_checkpoint_root, map_location='cpu'))

    preprocess_save_path = "./generated_feature"
    os.makedirs(os.path.join(preprocess_save_path, "labels"), exist_ok=True)
    os.makedirs(os.path.join(preprocess_save_path, "features"), exist_ok=True)

    modelvit.eval()
    modelgnn.eval()
    with torch.no_grad():
        # Blend with the final schedule weight (feat_end); validation split
        # first, then training, as in the original ordering.
        _export_fused_features(data_loader_val, dataset_val.graphs, dataset_val.name_list,
                               modelvit, modelgnn, arg.feat_end, preprocess_save_path, device)
        _export_fused_features(data_loader_train, dataset_train.graphs, dataset_train.name_list,
                               modelvit, modelgnn, arg.feat_end, preprocess_save_path, device)

