import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
import time
from timeit import default_timer as timer

from dataset.dataset import NetlistgnnDataset
from net.naive import TraditionalGNNModel
# from utils.losses import build_loss
# from models.build_model import build_model
from utils.configs import Parser
#from models.build_model import build_model
from logger import Logger, AverageMeter, time_to_str
from math import cos, pi
import sys, os






# Values of ``args.label`` that the loss selection below knows how to handle.
_SUPPORTED_LABELS = ("congestion", "DRC", "IR_drop", "all", "thermal")


def _select_loss(loss_f, pred, labels, label):
    """Return ``loss_f`` applied to the prediction/label columns chosen by ``label``.

    Args:
        loss_f: Callable taking ``(prediction, target)`` and returning a loss.
        pred: Model output, indexed as ``pred[:, columns]``.
        labels: Ground-truth tensor indexed the same way as ``pred``.
        label: One of ``_SUPPORTED_LABELS``.

    Raises:
        ValueError: If ``label`` is not a supported target name.
    """
    if label == "congestion":
        # Columns 0 and 1 are summed into a single column before comparison.
        return loss_f(
            torch.sum(pred[:, [0, 1]], 1, keepdim=True),
            torch.sum(labels[:, [0, 1]], 1, keepdim=True),
        )
    if label == "DRC":
        return loss_f(pred[:, [2]], labels[:, [2]])
    if label == "IR_drop":
        return loss_f(pred[:, [3]], labels[:, [3]])
    if label in ("all", "thermal"):
        return loss_f(pred, labels)
    raise ValueError(
        f"Unsupported label {label!r}; expected one of {_SUPPORTED_LABELS}"
    )


def _mean_group_loss(model, graphs, loss_f, label, device):
    """Run the model on every graph of one group and return the mean loss.

    ``graphs`` is an iterable of 2-tuples whose first element is the graph that
    is fed to ``model.wholeforward``; the second element is ignored here.
    Returns ``None`` when the group contains no graphs so the caller can skip
    it instead of dividing by zero.
    """
    losses = []
    for homo_graph, _ in graphs:
        homo_graph = homo_graph.to(device)
        pred = model.wholeforward(g=homo_graph, x=homo_graph.ndata['feat'])
        losses.append(
            _select_loss(loss_f, pred, homo_graph.ndata['label'], label)
        )
    if not losses:
        return None
    return sum(losses) / len(losses)


def _train_one_epoch(model, groups, optimizer, loss_f, args, device, epoch, log):
    """Train for one epoch, taking one optimizer step per graph group.

    Writes a single ``train`` row (epoch, average group loss, elapsed time) to
    ``log``. The learning-rate scheduler is NOT stepped here; the caller does
    that once per epoch.
    """
    if args.trans:
        # NOTE(review): assumes every entry of ``net_readout_params`` exposes
        # ``.train()`` (i.e. is a module rather than a bare parameter) — confirm.
        for p in model.net_readout_params:
            p.train()
    else:
        model.train()

    t1 = timer()
    total_loss_ = AverageMeter()
    for graphs in groups:
        # Gradients are only produced by the single backward() below, so
        # clearing them once per group is sufficient.
        optimizer.zero_grad()
        group_loss = _mean_group_loss(model, graphs, loss_f, args.label, device)
        if group_loss is None:
            continue
        group_loss.backward()
        optimizer.step()
        total_loss_.update(group_loss.item(), 1)

    message = '%s %6.0f | %0.3f  | %s\n' % (
        "train", epoch,
        total_loss_.avg,
        time_to_str((timer() - t1), 'min'))
    log.write(message)


def _evaluate_one_epoch(model, groups, loss_f, args, device, epoch, log):
    """Evaluate the model on ``groups`` without gradients and log a ``test`` row.

    Evaluation must not touch the optimizer or the learning-rate scheduler;
    stepping the scheduler here would decay the learning rate a second time on
    every evaluation epoch.
    """
    model.eval()
    t1 = timer()
    total_loss_ = AverageMeter()
    with torch.no_grad():
        for graphs in groups:
            group_loss = _mean_group_loss(model, graphs, loss_f, args.label, device)
            if group_loss is None:
                continue
            total_loss_.update(group_loss.item(), 1)

    message = '%s %6.0f | %0.3f  | %s\n' % (
        "test", epoch,
        total_loss_.avg,
        time_to_str((timer() - t1), 'min'))
    log.write(message)


def main():
    """Train a ``TraditionalGNNModel`` on the netlist graph dataset.

    Parses command-line arguments, seeds NumPy and PyTorch, builds the train
    and test datasets, optionally restores matching weights from an existing
    checkpoint, then trains for ``args.epochs`` epochs. The test set is
    evaluated on every fifth epoch (excluding epoch 0). Losses are written to
    ``<result_path>/<model>/<timestamp>/loss.log``.

    Raises:
        ValueError: If ``args.label`` is not one of ``_SUPPORTED_LABELS``.
    """
    argp = Parser()
    args = argp.parser.parse_args()

    # Fail fast, before any dataset loading or directory creation.
    if args.label not in _SUPPORTED_LABELS:
        raise ValueError(
            f"Unsupported label {args.label!r}; expected one of {_SUPPORTED_LABELS}"
        )

    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    log = Logger()
    result_dir = os.path.join(args.result_path, args.model, cur_time)
    # The checkpoint is shared across runs of the same model, so it lives one
    # level above the per-run timestamped directory.
    checkpoint_path = os.path.join(args.result_path, args.model, 'checkpoint.pth')
    os.makedirs(result_dir, exist_ok=True)
    log_file_name = os.path.join(result_dir, 'loss.log')
    log.open(log_file_name, mode="w")

    train_dataset = NetlistgnnDataset(args.data_root, args.train_list, args)
    train_list_tuple_graph = train_dataset.graphs
    test_dataset = NetlistgnnDataset(args.data_root, args.test_list, args)
    test_list_tuple_graph = test_dataset.graphs

    # Layer sizes are derived from the first training sample: element 0 of the
    # tuple carries the labels, element 1 carries the per-node 'hv' features.
    first_homo_graph, first_hetero_graph = train_list_tuple_graph[0][0][:2]
    in_node_feats = first_hetero_graph.nodes['node'].data['hv'].shape[1]
    first_label_shape = first_homo_graph.ndata['label'].shape
    output_dim = first_label_shape[1]
    arch = [in_node_feats, 400, 320, output_dim]

    print("Thermal label dimensions:")
    print(f"  - Output dimension: {output_dim}")
    print(f"  - Label shape: {first_label_shape}")
    print(f"  - Total nodes with labels: {first_label_shape[0]}")

    model = TraditionalGNNModel(
        model_type=args.model,
        arch_detail=arch,
        heads=int(args.heads),
        activation=args.outtype,
        scalefac=args.scalefac,
    ).to(device)
    if (not args.pretrain) and os.path.exists(checkpoint_path):
        print("checkpoint loading")
        # map_location lets a checkpoint written on GPU be restored on a
        # CPU-only host (and vice versa).
        # NOTE(review): the checkpoint is saved below as a whole pickled model;
        # recent PyTorch releases default torch.load to weights_only=True,
        # which may reject such files — confirm against the installed version.
        loaded = torch.load(checkpoint_path, map_location=device)
        source_state = loaded if isinstance(loaded, dict) else loaded.state_dict()
        model_dict = model.state_dict()
        # Keep only tensors whose name and shape match the current model so a
        # checkpoint from a slightly different architecture loads partially.
        compatible_state = {
            key: value
            for key, value in source_state.items()
            if key in model_dict and value.shape == model_dict[key].shape
        }
        model.load_state_dict(compatible_state, strict=False)
    else:
        print("checkpoint not load")

    loss_f = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Multiplies the learning rate by (1 - lr_decay) on every scheduler step.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=(1 - args.lr_decay))
    log.write(f' epochs = {args.epochs}     lr = {args.lr}    pretrain = {args.pretrain}')
    log.write('\n\nbegin_train: \n\n')

    for epoch in range(0, args.epochs):
        print(f'##### EPOCH {epoch} #####')
        print(f'\tLearning rate: {optimizer.param_groups[0]["lr"]}')
        log.write(f"----------------epoch={epoch}----------------\n")

        _train_one_epoch(
            model, train_list_tuple_graph, optimizer, loss_f, args, device, epoch, log
        )
        # Exactly one scheduler step per training epoch, after the optimizer
        # steps of that epoch.
        scheduler.step()

        if epoch > 0 and epoch % 5 == 0:
            _evaluate_one_epoch(
                model, test_list_tuple_graph, loss_f, args, device, epoch, log
            )

    # Only a pretraining run writes the shared checkpoint, and it never
    # overwrites an existing one. ("gengxin" is pinyin for "update".)
    if args.pretrain and (not os.path.exists(checkpoint_path)):
        print("gengxin checkpoint")
        torch.save(model, checkpoint_path)
    print("complete")



# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()