import os
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from feature_extration.process_data import Paraser, get_cell
from torch.utils.data import DataLoader
from multiprocessing import Process
from timeit import default_timer as timer
from logger import Logger, AverageMeter, time_to_str
from models.cnngnn import create_cnngnn
from models.NetlistGNN import NetlistGNN


def combine_feat(gnnfeat, vitfeat, pos, afa):
    """Cross-blend GNN features with the ViT features found at ``pos``.

    Args:
        gnnfeat: GNN feature tensor, broadcast against ``vitfeat[:, pos, :]``.
        vitfeat: ViT feature tensor indexed as ``[:, pos, :]``.
            NOTE: it is modified in place at those positions.
        pos: index applied to the second axis of ``vitfeat``.
        afa: mixing weight; each side keeps ``1 - afa`` of itself and takes
            ``afa`` of the other.

    Returns:
        A tuple ``(gnn_mixed, vitfeat, dif)``: the first entry along axis 0 of
        the blended GNN features (presumably batch size 1 -- TODO confirm),
        the same ``vitfeat`` object after the in-place write, and the summed
        absolute difference between the two feature sets before blending.
    """
    keep = 1 - afa
    # Read the ViT side before the in-place write below so that both blends
    # and the difference are computed from the pre-update values.
    vit_selected = vitfeat[:, pos, :]
    dif = torch.nn.functional.l1_loss(vit_selected, gnnfeat, reduction='sum')
    vitfeat[:, pos, :] = keep * vit_selected + afa * gnnfeat
    gnn_mixed = keep * gnnfeat + afa * vit_selected
    return gnn_mixed[0], vitfeat, dif

def get_afa(args, epoch):
    """Return the feature-mixing weight scheduled for ``epoch``.

    The schedule is ``np.linspace(args.feat_start, args.feat_end)`` with one
    entry per epoch in ``range(args.start_epoch, args.epochs)``; the entry for
    ``epoch`` is looked up by its offset from ``args.start_epoch``.

    Args:
        args: object exposing ``feat_start``, ``feat_end``, ``epochs`` and
            ``start_epoch``.
        epoch: absolute epoch number.

    Returns:
        The scheduled weight as a NumPy scalar.
    """
    n_scheduled = args.epochs - args.start_epoch
    schedule = np.linspace(args.feat_start, args.feat_end, num=n_scheduled)
    offset = epoch - args.start_epoch
    return schedule[offset]

def _run_batches(data_loader, graphs, modelvit, modelgnn, device, args, afa, train):
    """Run one pass over ``data_loader`` paired position-wise with ``graphs``.

    Args:
        data_loader: iterable yielding ``(features, labels, pos_flatten)``.
        graphs: iterable yielding, for each loader item, the graphs that
            belong to it.
        modelvit: model exposing ``forward_features`` and ``forward_heads``.
        modelgnn: model exposing ``forward_feaures`` and ``forward_heads``.
        device: device the features, labels and graphs are moved to.
        args: object exposing the loss weights ``l1``, ``l2`` and ``l3``.
        afa: mixing weight handed to ``combine_feat``.
        train: when true, back-propagate and step the optimizer per batch.
            The optimizer is passed as the ``optimizer`` attribute set by the
            caller (see ``train_one_epoch``).

    Returns:
        Dict of ``AverageMeter`` objects keyed ``'gnn'``, ``'vit'``, ``'dif'``
        and ``'total'``.
    """
    mse_sum = torch.nn.MSELoss(reduction='sum')
    meters = {key: AverageMeter() for key in ('gnn', 'vit', 'dif', 'total')}
    optimizer = train if train else None
    for (featurs, labels, pos_flatten), ltg in zip(data_loader, graphs):
        featurs = featurs.to(device)
        labels = labels.to(device)
        vit_feat = modelvit.forward_features(featurs)

        loss1 = 0.
        difs = 0.
        for hetero_graph, pos in zip(ltg, pos_flatten):
            hetero_graph = hetero_graph.to(device)
            # Fold the two coordinates into a single index along vit_feat's
            # second axis; 256 is presumably the grid width -- TODO confirm.
            pos = pos[0, :, 0] * 256 + pos[0, :, 1]
            gnnfeat = modelgnn.forward_feaures(
                in_node_feat=hetero_graph.nodes['node'].data['hv'],
                in_net_feat=hetero_graph.nodes['net'].data['hv'],
                in_pin_feat=hetero_graph.edges['pinned'].data['he'],
                in_edge_feat=hetero_graph.edges['near'].data['he'],
                node_net_graph=hetero_graph,
            )
            # combine_feat hands back the ViT features with this graph's
            # positions blended in, so later graphs see the updated tensor.
            gnnfeat, vit_feat, dif = combine_feat(gnnfeat, vit_feat, pos, afa)
            difs += dif
            gnn_pred = modelgnn.forward_heads(gnnfeat)
            gnn_label = hetero_graph.nodes['node'].data['label']
            loss1 += mse_sum(gnn_pred, gnn_label)

        vit_pred = modelvit.forward_heads(vit_feat)
        loss2 = mse_sum(vit_pred, labels)
        total_loss = args.l1 * loss1 + args.l2 * loss2 + args.l3 * difs

        if optimizer is not None:
            optimizer.zero_grad()
            total_loss.backward()  # back-propagation
            optimizer.step()  # parameter update

        # float() handles both scalar tensors and the plain 0. that loss1 and
        # difs keep when a sample carries no graphs (.item() would fail there).
        meters['gnn'].update(float(loss1), 1)
        meters['vit'].update(float(loss2), 1)
        meters['dif'].update(float(difs), 1)
        meters['total'].update(float(total_loss), 1)
    return meters


def _format_epoch_summary(tag, epoch, meters, start_time):
    """Build the one-line loss summary written to the log for ``tag``."""
    return '%s %6.0f | %0.3f | %0.3f | %0.3f | %0.3f | %s\n' % (
        tag, epoch,
        meters['gnn'].avg,
        meters['vit'].avg,
        meters['dif'].avg,
        meters['total'].avg,
        time_to_str((timer() - start_time), 'min'))


def train_one_epoch(epoch,log,data_loader_train,data_loader_val,train_graphs,test_graphs,optimizer,modelvit,modelgnn,device,args,checkpoint_dir=None):
    """Train both models for one epoch; every 10th epoch also checkpoint and validate.

    Args:
        epoch: absolute epoch number; also selects the mixing weight via
            ``get_afa``.
        log: object whose ``write`` method receives the log lines.
        data_loader_train: training loader yielding
            ``(features, labels, pos_flatten)``.
        data_loader_val: validation loader with the same item layout.
        train_graphs: per-sample graph collections, aligned by position with
            ``data_loader_train``.
        test_graphs: per-sample graph collections, aligned by position with
            ``data_loader_val``.
        optimizer: optimizer stepped once per training batch.
        modelvit: model exposing ``forward_features`` / ``forward_heads``.
        modelgnn: model exposing ``forward_feaures`` / ``forward_heads``.
        device: device the batches and graphs are moved to.
        args: object exposing ``l1``, ``l2``, ``l3`` plus the fields read by
            ``get_afa``.
        checkpoint_dir: directory receiving ``gnn.pth`` / ``vit.pth``. When
            omitted, the module-level ``result_dir`` (set by the script entry
            point) is used, as before.

    Returns:
        None. Results are reported through ``log`` and the checkpoint files.
    """
    start_time = timer()
    afa = get_afa(args, epoch)
    log.write(f"----------------epoch={epoch},afa={afa}-------------")

    modelvit.train()
    modelgnn.train()
    # _run_batches treats a truthy `train` value as the optimizer to step.
    train_meters = _run_batches(
        data_loader_train, train_graphs, modelvit, modelgnn, device, args, afa,
        train=optimizer if optimizer is not None else _MissingOptimizer())
    log.write(_format_epoch_summary("train", epoch, train_meters, start_time))

    # Checkpoint and validate only on every 10th epoch.
    if epoch % 10 != 0:
        return

    if checkpoint_dir is None:
        # Backward-compatible fallback: global defined by the __main__ block.
        checkpoint_dir = result_dir
    torch.save(modelgnn.state_dict(), os.path.join(checkpoint_dir, 'gnn.pth'))
    torch.save(modelvit.state_dict(), os.path.join(checkpoint_dir, 'vit.pth'))

    with torch.no_grad():
        modelvit.eval()
        modelgnn.eval()
        val_meters = _run_batches(
            data_loader_val, test_graphs, modelvit, modelgnn, device, args, afa,
            train=None)
        # The reported time is measured from the start of the epoch, so it
        # includes the training pass above.
        log.write(_format_epoch_summary("test", epoch, val_meters, start_time))


class _MissingOptimizer:
    """Stand-in that keeps the original failure mode when ``optimizer`` is None."""

    def __getattr__(self, name):
        raise AttributeError(
            "'NoneType' object has no attribute %r" % name)



if __name__ == '__main__':
    # Script entry point: parse arguments, build the datasets and both models,
    # train for the configured epochs, then save the final weights.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    argp = Paraser()
    arg = argp.parser.parse_args()

    # Create the per-run output directory and open the loss log inside it.
    # NOTE: `result_dir` must remain a module-level name; train_one_epoch
    # reads it when saving checkpoints.
    log = Logger()
    result_dir = os.path.join('results', arg.data_set, cur_time)
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(result_dir, exist_ok=True)
    log_file_name = os.path.join(result_dir, 'loss.log')
    log.open(log_file_name, mode="w")

    # The two datasets differ only in the last positional argument
    # (0 for the one used as validation, 1 for the one used for training).
    dataset_val = get_cell(
        arg.data_root, arg.graph_save_root, [arg.train_list, arg.test_list],
        arg.lef_path, arg.place_def_root, arg.route_def_root, 0)
    dataset_train = get_cell(
        arg.data_root, arg.graph_save_root, [arg.train_list, arg.test_list],
        arg.lef_path, arg.place_def_root, arg.route_def_root, 1)
    train_graphs = dataset_train.graphs
    test_graphs = dataset_val.graphs
    # shuffle=False matters: train_one_epoch pairs each loader item with the
    # entry of `dataset.graphs` at the same position.
    data_loader_train = DataLoader(dataset_train, batch_size=1, shuffle=False)
    data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False)

    print('##### GNN MODEL #####')
    # Feature widths are read off the first training graph.
    sample_graph = dataset_train.graphs[0][0]
    in_node_feats = sample_graph.nodes['node'].data['hv'].shape[1]
    in_net_feats = sample_graph.nodes['net'].data['hv'].shape[1]
    in_pin_feats = sample_graph.edges['pinned'].data['he'].shape[1]
    out_node_feats = sample_graph.nodes['node'].data['label'].shape[1]
    # Computed but not passed to the model below, which uses in_edge_feats=1.
    edge_node_feats = sample_graph.edges['near'].data['he'].shape[1]
    config = {
        'N_LAYER': arg.gnn_layers,
        'NODE_FEATS': arg.node_feats,
        'NET_FEATS': arg.net_feats,
        'PIN_FEATS': arg.pin_feats,
        'EDGE_FEATS': arg.edge_feats,
    }
    modelgnn = NetlistGNN(
        in_node_feats=in_node_feats,
        in_net_feats=in_net_feats,
        in_pin_feats=in_pin_feats,
        in_edge_feats=1,
        n_target=out_node_feats,
        activation=arg.outtype,
        config=config,
        recurrent=arg.recurrent,
        topo_conv_type=arg.topo_conv_type,
        geom_conv_type=arg.geom_conv_type,
        agg_type=arg.agg_type,
        cat_raw=arg.cat_raw
    ).to(device)

    modelvit = create_cnngnn(
        log,
        vit_model_name=arg.vit_model_name,
        vit_layers=arg.vit_layers,
        in_channel=dataset_train.map_feature_num,
        nb_classes=out_node_feats,
    ).to(device)

    # One Adam optimizer with a separate learning rate and weight decay for
    # each model's parameter group.
    optimizer = torch.optim.Adam([
        {"params": modelgnn.parameters(), "lr": arg.gnnlr, "weight_decay": arg.gnnweight_decay},
        {"params": modelvit.parameters(), "lr": arg.vitlr, "weight_decay": arg.vitweight_decay},
    ])

    log.write('\n\nbegin_train: \n\n')
    for epoch in range(arg.start_epoch, arg.epochs):
        train_one_epoch(epoch, log, data_loader_train, data_loader_val, train_graphs, test_graphs,
                        optimizer, modelvit, modelgnn, device, arg)

    # Final weights, written after the last epoch regardless of the periodic
    # checkpoints made inside train_one_epoch.
    torch.save(modelgnn.state_dict(), os.path.join(result_dir, 'gnn.pth'))
    torch.save(modelvit.state_dict(), os.path.join(result_dir, 'vit.pth'))










