import os
import json
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import time
from timeit import default_timer as timer

from dataset.dataset import CircuitNetDataset
# from utils.losses import build_loss
# from models.build_model import build_model
from utils.configs import Parser
from models.build_model import build_model
from logger import Logger, AverageMeter, time_to_str
from math import cos, pi
import sys, os
from dataset.vitgnn_dataset import VitgnnDataset


def _batch_losses(output, labels, args, criterion):
    """Compute the loss for one batch.

    Returns a ``(total, parts)`` pair. When ``args.label == "all"`` the loss
    is evaluated separately on channels 0, 1 and 2 (dimension 1 of the
    tensors) and ``parts`` holds those three terms; otherwise ``parts`` is
    empty and ``total`` is the criterion over the whole tensors.
    """
    if args.label == "all":
        parts = [
            criterion(output[:, [channel], :, :], labels[:, [channel], :, :])
            for channel in range(3)
        ]
        return parts[0] + parts[1] + parts[2], parts
    return criterion(output, labels), []


def _run_pass(data_loader, model, device, args, optimizer=None):
    """Iterate once over ``data_loader`` and accumulate per-batch losses.

    If ``optimizer`` is given, a gradient step is taken on every batch;
    otherwise the pass is forward-only. The caller is responsible for
    selecting train/eval mode and for disabling gradients when appropriate.

    Returns ``(total_meter, channel_meters)`` where ``channel_meters`` is a
    list of three meters that are only updated when ``args.label == "all"``.
    """
    # Built once per pass; the module is stateless so this is equivalent to
    # creating it for every batch.
    criterion = torch.nn.MSELoss(reduction='mean')
    total_meter = AverageMeter()
    channel_meters = [AverageMeter() for _ in range(3)]

    for features, labels in data_loader:
        features = features.to(device).float()
        labels = labels.to(device).float()
        output = model(features)

        loss, parts = _batch_losses(output, labels, args, criterion)
        for meter, part in zip(channel_meters, parts):
            meter.update(part.item(), 1)

        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        total_meter.update(loss.item(), 1)

    return total_meter, channel_meters


def _format_loss_row(tag, epoch, total_meter, channel_meters, start_time):
    """Build one log line: tag, epoch, three channel averages, total, elapsed."""
    return '%s %6.0f |  %0.3f |  %0.3f |  %0.3f |  %0.3f | %s\n' % (
        tag, epoch,
        channel_meters[0].avg,
        channel_meters[1].avg,
        channel_meters[2].avg,
        total_meter.avg,
        time_to_str((timer() - start_time), 'min'))


def train_one_epoch(epoch, log, data_loader_train, data_loader_val, optimizer, model, device, args):
    """Train ``model`` for one epoch and, every fifth epoch, validate it.

    Args:
        epoch: Index of the current epoch; validation runs when it is a
            multiple of 5 (including epoch 0).
        log: Object exposing ``write(str)``; receives the epoch banner and
            the loss rows.
        data_loader_train: Iterable of ``(features, labels)`` training batches.
        data_loader_val: Iterable of ``(features, labels)`` validation batches.
        optimizer: Optimizer stepped once per training batch.
        model: Callable module mapping features to predictions.
        device: Device the batches are moved to before the forward pass.
        args: Namespace; ``args.label == "all"`` selects the three-channel
            loss, any other value a single MSE over the whole output.

    The elapsed time written in both rows is measured from the start of the
    epoch, so the validation row includes the training time.
    """
    start_time = timer()
    log.write(f"----------------epoch={epoch}------------------\n")

    # Validation below switches the model to eval mode; switch back here so
    # every epoch after a validated one actually trains in training mode.
    model.train()
    total_meter, channel_meters = _run_pass(
        data_loader_train, model, device, args, optimizer=optimizer)
    log.write(_format_loss_row("train", epoch, total_meter, channel_meters, start_time))

    if epoch % 5 == 0:
        model.eval()
        with torch.no_grad():
            total_meter, channel_meters = _run_pass(
                data_loader_val, model, device, args)
        # Tagged "val" so validation rows are distinguishable from the
        # training rows written above.
        log.write(_format_loss_row("val", epoch, total_meter, channel_meters, start_time))


def train():
    """Parse the command line, build datasets and model, and run training.

    Reads from the parsed arguments: ``result_path``, ``model``,
    ``data_root``, ``train_list``, ``test_list``, ``batch_size``,
    ``pretrain``, ``lr``, ``weight_decay`` and ``epochs``.

    Checkpoint handling:
        * When ``args.pretrain`` is falsy and a checkpoint file exists, the
          entries whose name and shape match the freshly built model are
          loaded into it (non-strict).
        * When ``args.pretrain`` is truthy and no checkpoint file exists yet,
          the whole model object is saved there after training.

    Losses are written to ``loss.log`` inside a timestamped result directory.
    """
    args = Parser().parser.parse_args()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())

    run_root = os.path.join(args.result_path, args.model+"22")
    result_dir = os.path.join(run_root, cur_time)
    checkpoint_path = os.path.join(run_root, 'checkpoint.pth')
    os.makedirs(result_dir, exist_ok=True)

    log = Logger()
    log.open(os.path.join(result_dir, 'loss.log'), mode="w")

    #args.label = "thermal"
    # NOTE(review): machine-specific path that unconditionally overrides any
    # value on ``args`` — consider exposing it as a command-line option.
    args.vitgnndataroot = "/home/shiyu/zjb/cnngnn2/generated_feature/"
    train_dataset = VitgnnDataset(args.data_root, args.vitgnndataroot, args.train_list, args)
    in_feature_dim = train_dataset.features.shape[1]
    out_feature_dim = train_dataset.labels.shape[1]
    val_dataset = VitgnnDataset(args.data_root, args.vitgnndataroot, args.test_list, args)
    # NOTE(review): the training loader is not shuffled; kept as in the
    # original — confirm this is intentional.
    data_loader_train = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=False)
    data_loader_val = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    model = build_model(args, in_feature_dim, out_feature_dim).to(device)

    if (not args.pretrain) and os.path.exists(checkpoint_path):
        print("checkpoint loading")
        # Load onto the CPU: a checkpoint written on a GPU host would
        # otherwise fail to deserialize on a CPU-only machine, and would
        # allocate a second copy of the model on the GPU. load_state_dict
        # copies the values into the model's own device afterwards.
        # NOTE(review): this unpickles an arbitrary object; only load trusted
        # checkpoint files. Newer PyTorch releases may default torch.load to
        # weights_only=True, which rejects whole-module pickles — confirm
        # against the installed version.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        # The file normally holds a whole module (see torch.save below), but
        # accept a plain state dict as well.
        if hasattr(checkpoint, "state_dict"):
            saved_state = checkpoint.state_dict()
        else:
            saved_state = checkpoint
        model_state = model.state_dict()
        # Keep only tensors that exist in the new model with the same shape,
        # so layers whose dimensions changed are left at their fresh init.
        compatible_state = {
            key: value for key, value in saved_state.items()
            if key in model_state and value.shape == model_state[key].shape
        }
        model.load_state_dict(compatible_state, strict=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    log.write(f' epochs = {args.epochs}     lr = {args.lr}    pretrain = {args.pretrain}')
    log.write('\n\nbegin_train: \n\n')
    # Runs epochs 0..args.epochs inclusive (args.epochs + 1 iterations).
    for epoch in range(args.epochs+1):
        train_one_epoch(epoch, log, data_loader_train, data_loader_val, optimizer, model, device, args)

    if args.pretrain and not os.path.exists(checkpoint_path):
        # Saved as a whole module to stay compatible with the loader above
        # and with any existing checkpoints.
        torch.save(model, checkpoint_path)



# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    train()