import os
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from feature_extration.process_data import Paraser, get_cell
from torch.utils.data import DataLoader
from multiprocessing import Process
from timeit import default_timer as timer
from logger import Logger, AverageMeter, time_to_str
from models.cnngnn import create_cnngnn
from models.NetlistGNN import NetlistGNN


def combine_feat(gnnfeat, vitfeat, pos, afa):
    """Mix GNN features and the ViT tokens selected by ``pos`` towards each other.

    With ``vit`` denoting ``vitfeat[:, pos, :]`` *before* this call:
      * ``vitfeat[:, pos, :]`` is overwritten in place with ``(1 - afa) * vit + afa * gnnfeat``;
      * the returned GNN features are ``(1 - afa) * gnnfeat + afa * vit``.

    Args:
        gnnfeat: GNN features; must broadcast against ``vitfeat[:, pos, :]``.
        vitfeat: ViT feature tensor, indexed as ``[:, pos, :]`` (assumed
            ``(batch, token, dim)`` — TODO confirm). Modified in place.
        pos: index (or indices) along dim 1 selecting the tokens to mix.
        afa: mixing weight in the expressions above.

    Returns:
        ``(mixed_gnnfeat[0], vitfeat, dif)`` where ``dif`` is the summed L1
        distance between the original selected ViT tokens and ``gnnfeat``.
    """
    # Clone: for an int/slice ``pos`` the selection is a *view* of ``vitfeat``.
    # Without the copy, the in-place write below would also overwrite
    # ``vit_part``. The GNN side would then be mixed with already-mixed values
    # instead of the original ViT tokens.
    vit_part = vitfeat[:, pos, :].clone()
    dif = torch.nn.L1Loss(reduction='sum')(vit_part, gnnfeat)
    vitfeat[:, pos, :] = (1 - afa) * vit_part + afa * gnnfeat
    gnnfeat = (1 - afa) * gnnfeat + afa * vit_part
    return gnnfeat[0], vitfeat, dif

def get_afa(args, epoch):
    """Return the feature-mixing weight scheduled for ``epoch``.

    The schedule is ``np.linspace(args.feat_start, args.feat_end)`` with one
    sample per epoch in ``[args.start_epoch, args.epochs)``. The first
    scheduled epoch gets ``feat_start`` and the last gets ``feat_end``.
    """
    n_steps = args.epochs - args.start_epoch
    schedule = np.linspace(args.feat_start, args.feat_end, num=n_steps)
    offset = epoch - args.start_epoch
    return schedule[offset]

def _group_losses(modelgnn, graph_group, criterion):
    """Sum ``criterion`` over one group of graphs, separately for label columns 0, 1 and 2.

    Args:
        modelgnn: model exposing ``forward_feaures`` and ``forward_heads``.
        graph_group: iterable of heterographs with 'node'/'net' node types and
            'pinned'/'near' edge types (DGL-style, presumably).
        criterion: loss module applied to each target column.

    Returns:
        ``[loss0, loss1, loss2]`` as tensors, or ``None`` if the group
        contains no graphs.
    """
    losses = None
    for hetero_graph in graph_group:
        gnnfeat = modelgnn.forward_feaures(
            in_node_feat=hetero_graph.nodes['node'].data['hv'],
            in_net_feat=hetero_graph.nodes['net'].data['hv'],
            in_pin_feat=hetero_graph.edges['pinned'].data['he'],
            in_edge_feat=hetero_graph.edges['near'].data['he'],
            node_net_graph=hetero_graph,
        )
        # Only the first element of forward_feaures' result feeds the heads.
        gnn_pred = modelgnn.forward_heads(gnnfeat[0])
        gnn_label = hetero_graph.nodes['node'].data['label']
        per_target = [criterion(gnn_pred[:, k], gnn_label[:, k]) for k in range(3)]
        if losses is None:
            losses = per_target
        else:
            losses = [acc + cur for acc, cur in zip(losses, per_target)]
    return losses


def _run_pass(modelgnn, graph_groups, criterion, optimizer=None):
    """Compute the losses for every group in ``graph_groups`` and average them.

    When ``optimizer`` is given, one backward pass and one optimizer step are
    taken per group (training). Otherwise the losses are only computed;
    ``train_one_epoch`` wraps that case in ``torch.no_grad()``.

    Returns:
        Four ``AverageMeter``s holding loss0, loss1, loss2 and their total.
        Each is updated once per non-empty group.
    """
    meters = [AverageMeter() for _ in range(4)]
    for graph_group in graph_groups:
        losses = _group_losses(modelgnn, graph_group, criterion)
        if losses is None:
            # An empty group yields no loss tensor to backpropagate or report
            # (previously this crashed on ``float.item()``).
            continue
        total_loss = losses[1] + losses[2] + losses[0]
        if optimizer is not None:
            optimizer.zero_grad()
            total_loss.backward()  # backpropagate
            optimizer.step()  # update parameters
        for meter, value in zip(meters, (*losses, total_loss)):
            meter.update(value.item(), 1)
    return meters


def _loss_message(tag, epoch, meters, start_time):
    """Format one log line with the tag, epoch, four averaged losses and time elapsed since ``start_time``."""
    loss0, loss1, loss2, total = (m.avg for m in meters)
    return '%s %6.0f | %0.3f | %0.3f | %0.3f | %0.3f | %s\n' % (
        tag, epoch, loss0, loss1, loss2, total,
        time_to_str((timer() - start_time), 'min'))


def train_one_epoch(epoch,log,data_loader_train,data_loader_val,train_graphs,test_graphs,optimizer,modelvit,modelgnn,device,args):
    """Train ``modelgnn`` for one epoch and evaluate it on ``test_graphs`` every 10th epoch.

    The loss for each graph group is the summed MSE of label columns 0, 1 and 2,
    added up over all graphs in the group. One optimizer step is taken per group.

    Args:
        epoch: current epoch number; also selects the scheduled ``afa``, which
            is only logged here.
        log: logger with a ``write`` method; messages carry their own newlines.
        data_loader_train, data_loader_val, modelvit, device: unused by this
            function; kept for interface compatibility.
        train_graphs: iterable of graph groups used for training.
        test_graphs: iterable of graph groups used for evaluation.
        optimizer: optimizer over ``modelgnn``'s parameters.
        modelgnn: the GNN being trained.
        args: namespace read by ``get_afa``.
    """
    start_time = timer()
    afa = get_afa(args, epoch)
    log.write(f"----------------epoch={epoch},afa={afa}-------------\n")
    criterion = nn.MSELoss(reduction='sum')

    modelgnn.train()
    train_meters = _run_pass(modelgnn, train_graphs, criterion, optimizer)
    log.write(_loss_message("train", epoch, train_meters, start_time))

    if epoch % 10 == 0:
        with torch.no_grad():
            modelgnn.eval()
            # Evaluate on the held-out graphs (previously this iterated train_graphs).
            test_meters = _run_pass(modelgnn, test_graphs, criterion)
        # Note: the reported time is measured from the start of this epoch.
        log.write(_loss_message("test", epoch, test_meters, start_time))








if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    argp = Paraser()
    arg = argp.parser.parse_args()

    # Create the per-run results directory and its log file.
    log = Logger()
    result_dir = os.path.join('results', arg.data_set + '_gnn', cur_time)
    os.makedirs(result_dir, exist_ok=True)
    log_file_name = os.path.join(result_dir, 'loss.log')
    log.open(log_file_name, mode="w")

    # NOTE(review): training reuses the dataset loaded with get_cell(..., 0).
    # A previous version loaded a separate set with last argument 1
    # (presumably the train split). As written, train and test graphs are
    # the same objects — confirm this is intended.
    dataset_val = get_cell(arg.data_root, arg.graph_save_root, [arg.train_list, arg.test_list],
                           arg.lef_path, arg.place_def_root, arg.route_def_root, 0)
    dataset_train = dataset_val
    train_graphs = dataset_train.graphs
    test_graphs = dataset_val.graphs

    print('##### GNN MODEL #####')
    # Infer input/output feature sizes from the first graph of the first group.
    sample_graph = dataset_train.graphs[0][0]
    in_node_feats = sample_graph.nodes['node'].data['hv'].shape[1]
    in_net_feats = sample_graph.nodes['net'].data['hv'].shape[1]
    in_pin_feats = sample_graph.edges['pinned'].data['he'].shape[1]
    out_node_feats = sample_graph.nodes['node'].data['label'].shape[1]
    edge_node_feats = sample_graph.edges['near'].data['he'].shape[1]
    config = {
        'N_LAYER': arg.gnn_layers,
        'NODE_FEATS': arg.node_feats,
        'NET_FEATS': arg.net_feats,
        'PIN_FEATS': arg.pin_feats,
        'EDGE_FEATS': arg.edge_feats,
    }
    modelgnn = NetlistGNN(
        in_node_feats=in_node_feats,
        in_net_feats=in_net_feats,
        in_pin_feats=in_pin_feats,
        # Use the measured 'near' edge feature width instead of a hard-coded 1.
        in_edge_feats=edge_node_feats,
        n_target=out_node_feats,
        activation=arg.outtype,
        config=config,
        recurrent=arg.recurrent,
        topo_conv_type=arg.topo_conv_type,
        geom_conv_type=arg.geom_conv_type,
        agg_type=arg.agg_type,
        cat_raw=arg.cat_raw
    ).to(device)

    optimizer = torch.optim.Adam(modelgnn.parameters(), lr=arg.gnnlr, weight_decay=arg.gnnweight_decay)

    # NOTE(review): this overrides whatever --epochs value was parsed above.
    arg.epochs = 100
    log.write('\n\nbegin_train: \n\n')
    for epoch in range(arg.start_epoch, arg.epochs):
        train_one_epoch(epoch, log, None, None, train_graphs, test_graphs, optimizer, None, modelgnn, device, arg)

    torch.save(modelgnn.state_dict(), os.path.join(result_dir, 'gnn.pth'))










