import os
import json
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
from timeit import default_timer as timer

from dataset.dataset import CircuitNetDataset
# from utils.losses import build_loss
# from models.build_model import build_model
from utils.configs import Parser
from models.build_model import build_model
from logger import Logger, AverageMeter, time_to_str
from math import cos, pi
import sys, os
from dataset.vitgnn_dataset import VitgnnDataset


def train_one_epoch(epoch,log,data_loader_train,data_loader_val,optimizer,model,device,args):
    """Train the model for one epoch and, on every 10th epoch, evaluate it.

    Args:
        epoch: Index of the current epoch; written to the log and used to
            decide whether validation runs (``epoch % 10 == 0``).
        log: Logger object exposing ``write(str)``.
        data_loader_train: Iterable yielding ``(features, labels)`` training batches.
        data_loader_val: Iterable yielding ``(features, labels)`` validation batches.
        optimizer: Optimizer stepped once per training batch.
        model: Module being trained; called as ``model(features)``.
        device: Device that features and labels are moved to.
        args: Parsed options; only ``args.label`` is read here.

    Returns:
        None. The per-batch MSE losses are accumulated in an ``AverageMeter``
        and its ``avg`` is written to ``log`` together with the elapsed time.
    """
    start_time = timer()
    log.write(f"----------------epoch={epoch}------------------\n")
    # The loss module is stateless, so build it once instead of once per batch.
    criterion = torch.nn.MSELoss(reduction='mean')

    # The validation branch below puts the model in eval mode and never undoes
    # it. Without this call every epoch after the first validation (epoch 0)
    # would train with dropout disabled and batch-norm statistics frozen.
    model.train()
    train_loss = AverageMeter()
    for features, labels in data_loader_train:
        features = features.to(device).float()
        labels = labels.to(device).float()
        output = model(features)
        if args.label == "congestion":
            # For the congestion label, collapse dim 1 of both prediction and
            # target so the loss is computed on their sums.
            output = torch.sum(output, 1, keepdim=True)
            labels = torch.sum(labels, 1, keepdim=True)
        loss = criterion(output, labels)

        optimizer.zero_grad()
        loss.backward()  # backpropagation
        optimizer.step()  # update parameters

        train_loss.update(loss.item(), 1)
    message = '%s %6.0f |  %0.3f | %s\n' % (
        "train", epoch,
        train_loss.avg,
        time_to_str((timer() - start_time), 'min'))
    log.write(message)

    # Validation runs only on every 10th epoch (including epoch 0).
    if epoch % 10 == 0:
        with torch.no_grad():
            model.eval()
            val_loss = AverageMeter()
            for features, labels in data_loader_val:
                features = features.to(device).float()
                labels = labels.to(device).float()
                output = model(features)
                # NOTE(review): unlike the training loop, no dim-1 sum is applied
                # here when args.label == "congestion", so train and test losses
                # may not be directly comparable -- confirm this is intended.
                loss = criterion(output, labels)
                val_loss.update(loss.item(), 1)
            # The elapsed time is measured from the start of the epoch, so it
            # includes the training pass as well.
            message = '%s %6.0f |  %0.3f | %s\n' % (
                "test", epoch,
                val_loss.avg,
                time_to_str((timer() - start_time), 'min'))
            log.write(message)


def train():
    """Parse options, build the data loaders, model and optimizer, then train.

    The loss log is written to ``<result_path>/<model>/<timestamp>/loss.log``.
    After the last epoch the whole model object is saved to
    ``<result_path>/<model>/checkpoint.pth``, but only when ``args.pretrain``
    is set and no checkpoint file exists there yet.
    """
    argp = Parser()
    args = argp.parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    log = Logger()
    result_dir = os.path.join(args.result_path, args.model, cur_time)
    checkpoint_path = os.path.join(args.result_path, args.model, 'checkpoint.pth')
    # Creating the timestamped directory also creates its parent, which is the
    # directory the checkpoint is saved into.
    os.makedirs(result_dir, exist_ok=True)
    log_file_name = os.path.join(result_dir, 'loss.log')
    log.open(log_file_name, mode="w")

    train_dataset = CircuitNetDataset(args.data_root, args.train_list, args)
    in_feature_dim = train_dataset.feature_dim
    out_feature_dim = train_dataset.label_dim
    val_dataset = CircuitNetDataset(args.data_root, args.test_list, args)
    # Reshuffle the training samples every epoch so batches are not presented
    # in the same fixed order each time; validation order stays deterministic.
    data_loader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    data_loader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    model = build_model(args, in_feature_dim, out_feature_dim).to(device)
    # NOTE: loading weights from checkpoint_path for non-pretrain runs is
    # currently disabled; the model always starts from build_model's
    # initialisation.

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    log.write(f' epochs = {args.epochs}     lr = {args.lr}    pretrain = {args.pretrain}')
    log.write('\n\nbegin_train: \n\n')
    # Epoch indices run from 0 to args.epochs inclusive (args.epochs + 1 passes).
    for epoch in range(args.epochs+1):
        train_one_epoch(epoch,log,data_loader_train,data_loader_val,optimizer,model,device,args)

    # An existing checkpoint is never overwritten.
    if (args.pretrain) and (not os.path.exists(checkpoint_path)):
        torch.save(model,checkpoint_path)



# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()