import os
import time
import argparse
import numpy as np
import torch
import torch.nn as nn
from feature_extration.process_data import Paraser, get_cell
from feature_extration.src.vit_dataset import get_vit_dataset
from torch.utils.data import DataLoader
from multiprocessing import Process
from timeit import default_timer as timer
from logger import Logger, AverageMeter, time_to_str
from models.cnngnn import create_cnngnn
from models.NetlistGNN import NetlistGNN


def combine_feat(gnnfeat, vitfeat, pos, afa):
    """Cross-blend GNN features with the ViT features selected by ``pos``.

    The ViT slice ``vitfeat[:, pos, :]`` is overwritten in place with
    ``(1 - afa) * vit_part + afa * gnnfeat``.  The GNN features become
    ``(1 - afa) * gnnfeat + afa * vit_part``.  In both formulas
    ``vit_part`` is the ViT slice as it was *before* the in-place update.

    Args:
        gnnfeat: GNN feature tensor; must broadcast against
            ``vitfeat[:, pos, :]``.
        vitfeat: ViT feature tensor indexed as ``[:, pos, :]``.  It is
            modified in place and also returned.
        pos: Index along dim 1 (int, slice, list or index tensor).
        afa: Mixing weight.  0 leaves both feature sets unchanged; 1 swaps
            them.

    Returns:
        Tuple ``(gnnfeat[0], vitfeat, dif)``.  ``dif`` is the summed L1
        distance between the original ViT slice and ``gnnfeat``.
    """
    # Clone so the GNN update below sees the pre-update ViT values.  With an
    # int/slice ``pos``, basic indexing returns a view, and the in-place
    # write to ``vitfeat`` would otherwise overwrite it.
    vit_part = vitfeat[:, pos, :].clone()
    dif = torch.nn.L1Loss(reduction='sum')(vit_part, gnnfeat)
    vitfeat[:, pos, :] = (1 - afa) * vit_part + afa * gnnfeat
    gnnfeat = (1 - afa) * gnnfeat + afa * vit_part
    return gnnfeat[0], vitfeat, dif

def get_afa(args, epoch):
    """Return the feature-mixing weight for ``epoch``.

    The weight is interpolated linearly from ``args.feat_start`` (at
    ``args.start_epoch``) to ``args.feat_end`` (at ``args.epochs - 1``).
    This uses ``np.linspace`` with one sample per epoch in
    ``[args.start_epoch, args.epochs)``.

    Args:
        args: Namespace with ``feat_start``, ``feat_end``, ``epochs`` and
            ``start_epoch`` attributes.
        epoch: Current epoch number.

    Raises:
        IndexError: If ``epoch`` is outside
            ``[args.start_epoch, args.epochs)``.
    """
    num = args.epochs - args.start_epoch
    idx = epoch - args.start_epoch
    # Explicit check: a negative index would otherwise wrap around silently
    # and return a value from the end of the schedule.
    if not 0 <= idx < num:
        raise IndexError(
            f"epoch {epoch} outside schedule range "
            f"[{args.start_epoch}, {args.epochs})"
        )
    return np.linspace(args.feat_start, args.feat_end, num=num)[idx]

def _vit_batch_loss(modelvit, criterion, features, labels, device):
    """Move one batch to ``device``, run the ViT and return the loss tensor."""
    features = features.to(device)
    labels = labels.to(device)
    vit_feat = modelvit.forward_features(features)
    vit_pred = modelvit.forward_heads(vit_feat)
    return criterion(vit_pred, labels)


def _loss_message(tag, epoch, avg_loss, start_time):
    """Format one log line: tag, epoch, mean loss and elapsed time in minutes."""
    return '%s %6.0f |  %0.3f | %s\n' % (
        tag, epoch,
        avg_loss,
        time_to_str((timer() - start_time), 'min'))


def train_one_epoch(epoch, log, data_loader_train, data_loader_val, train_graphs, test_graphs,
                    optimizer, modelvit, modelgnn, device, args):
    """Train ``modelvit`` for one epoch and validate it every 10th epoch.

    Each batch is moved to ``device`` and passed through
    ``modelvit.forward_features`` then ``modelvit.forward_heads``.  The loss
    is a summed MSE against the labels.  The mean per-batch loss of each
    pass is written to ``log``.

    Args:
        epoch: Current epoch number.  Validation runs when
            ``epoch % 10 == 0``.
        log: Logger providing ``write(str)``.
        data_loader_train: Iterable of ``(features, labels)`` batches used
            for training.
        data_loader_val: Iterable of ``(features, labels)`` batches used for
            validation.
        train_graphs, test_graphs, modelgnn, args: Unused here; kept for
            interface compatibility.
        optimizer: Optimizer over ``modelvit``'s parameters.
        modelvit: Model being trained.
        device: Device the model lives on; batches are moved there.
    """
    start_time = timer()
    log.write(f"----------------epoch={epoch}------------------\n")
    criterion = torch.nn.MSELoss(reduction='sum')

    # Training pass.
    modelvit.train()
    train_meter = AverageMeter()
    for features, labels in data_loader_train:
        loss = _vit_batch_loss(modelvit, criterion, features, labels, device)
        optimizer.zero_grad()
        loss.backward()   # backpropagate
        optimizer.step()  # update parameters
        train_meter.update(loss.item(), 1)
    log.write(_loss_message("train", epoch, train_meter.avg, start_time))

    # Validation pass, every 10 epochs.
    if epoch % 10 == 0:
        modelvit.eval()
        val_meter = AverageMeter()
        with torch.no_grad():
            for features, labels in data_loader_val:
                loss = _vit_batch_loss(modelvit, criterion, features, labels, device)
                val_meter.update(loss.item(), 1)
        # Elapsed time is measured from the start of the epoch, so it
        # includes the training pass.
        log.write(_loss_message("test", epoch, val_meter.avg, start_time))



def main():
    """Parse CLI args, train the ViT model and save its weights.

    The loss log and the final ``vit.pth`` checkpoint are written to
    ``results/<data_set>/<timestamp>/``.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    arg = Paraser().parser.parse_args()

    # Create the results directory and the log file.
    # exist_ok avoids the race between checking for the directory and
    # creating it.
    log = Logger()
    result_dir = os.path.join('results', arg.data_set, cur_time)
    os.makedirs(result_dir, exist_ok=True)
    log.open(os.path.join(result_dir, 'loss.log'), mode="w")

    # NOTE(review): the training loader iterates the *same* dataset object
    # as the validation loader, so the logged "test" loss is not a held-out
    # metric.  A separate training set was previously built with
    #   get_vit_dataset(arg.data_root, [arg.train_list, arg.test_list], 1)
    # (presumably the training split).  Confirm which split is intended.
    dataset_val = get_vit_dataset(arg.data_root, [arg.train_list, arg.test_list], 0)
    dataset_train = dataset_val
    data_loader_train = DataLoader(dataset_train, batch_size=1, shuffle=False)
    data_loader_val = DataLoader(dataset_val, batch_size=1, shuffle=False)

    modelvit = create_cnngnn(
        log,
        vit_model_name=arg.vit_model_name,
        vit_layers=arg.vit_layers,
        in_channel=dataset_train.all_features.shape[1],
        nb_classes=dataset_train.all_labels.shape[1],
    ).to(device)

    optimizer = torch.optim.Adam(modelvit.parameters(), lr=arg.lr, weight_decay=arg.weight_decay)

    log.write('\n\nbegin_train: \n\n')
    for epoch in range(arg.start_epoch, arg.epochs):
        train_one_epoch(epoch, log, data_loader_train, data_loader_val, None, None,
                        optimizer, modelvit, None, device, arg)

    torch.save(modelvit.state_dict(), os.path.join(result_dir, 'vit.pth'))


if __name__ == '__main__':
    main()










