import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import time
from timeit import default_timer as timer

from dataset.dataset import NetlistgnnDataset
# from utils.losses import build_loss
# from models.build_model import build_model
from utils.configs import Parser
#from models.build_model import build_model
from logger import Logger, AverageMeter, time_to_str
from math import cos, pi
import sys, os
from net.NetlistGNN import NetlistGNN


# Column indices of the node / net 'hv' feature matrices that are kept when the
# script runs in topology-only mode (args.topo_geom == 'topo').
NODE_TOPO_FEAT = [*range(3), *range(39, 42)]
NET_TOPO_FEAT = [0]


# Target name -> columns of the prediction / label matrices scored by the loss.
# ``None`` means every column is used.
_LABEL_COLUMNS = {
    "congestion": [0, 1],
    "DRC": [2],
    "IR_drop": [3],
    "all": None,
    "thermal": None,
}


def _label_columns(label):
    """Return the column indices scored for ``label`` (``None`` = all columns).

    Raises:
        ValueError: if ``label`` is not one of the supported targets.
    """
    if label not in _LABEL_COLUMNS:
        raise ValueError(
            f"unsupported label {label!r}; expected one of {sorted(_LABEL_COLUMNS)}"
        )
    return _LABEL_COLUMNS[label]


def _compute_loss(loss_f, pred, labels, label):
    """Apply ``loss_f`` to the prediction/label columns selected by ``label``."""
    cols = _label_columns(label)
    if cols is not None:
        pred = pred[:, cols]
        labels = labels[:, cols]
    return loss_f(pred, labels)


def _build_inputs(args, homo_graph, hetero_graph):
    """Assemble the ``(in_node_feat, in_net_feat)`` model inputs for one graph pair.

    The graph's stored feature tensors are never modified in place.
    """
    in_node_feat = hetero_graph.nodes['node'].data['hv']
    in_net_feat = hetero_graph.nodes['net'].data['hv']
    if args.pos_code > 1e-5 and args.topo_geom != 'topo':
        # Out-of-place add: ``+=`` would write into the graph's own 'hv'
        # tensor and, whenever ``.to(device)`` does not copy, accumulate the
        # positional code into the dataset on every epoch.
        in_node_feat = in_node_feat + args.pos_code * hetero_graph.nodes['node'].data['pos_code']
    if args.topo_geom == 'topo':
        in_node_feat = in_node_feat[:, NODE_TOPO_FEAT]
        in_net_feat = in_net_feat[:, NET_TOPO_FEAT]
    if args.add_pos:
        in_node_feat = torch.cat([in_node_feat, homo_graph.ndata['pos']], dim=-1)
    return in_node_feat, in_net_feat


def _forward(model, args, homo_graph, hetero_graph):
    """Run ``model`` on one graph pair and return the prediction scaled by ``args.scalefac``."""
    in_node_feat, in_net_feat = _build_inputs(args, homo_graph, hetero_graph)
    pred, _ = model(
        in_node_feat=in_node_feat,
        in_net_feat=in_net_feat,
        in_pin_feat=hetero_graph.edges['pinned'].data['he'],
        in_edge_feat=hetero_graph.edges['near'].data['he'],
        node_net_graph=hetero_graph,
    )
    return pred * args.scalefac


def _run_epoch(model, ltg, args, device, loss_f, optimizer=None):
    """Make one pass over ``ltg`` and return the running-average group loss.

    ``ltg`` is a sequence of groups, each an iterable of
    ``(homo_graph, hetero_graph)`` pairs. A group's loss is the mean of its
    per-graph losses. When ``optimizer`` is given, one backward pass and one
    optimizer step are taken per group; otherwise autograd is disabled.
    Empty groups are skipped.
    """
    training = optimizer is not None
    total_loss_ = AverageMeter()
    with torch.set_grad_enabled(training):
        for graphs in ltg:
            if training:
                optimizer.zero_grad()
            losses = []
            for homo_graph, hetero_graph in graphs:
                homo_graph = homo_graph.to(device)
                hetero_graph = hetero_graph.to(device)
                pred = _forward(model, args, homo_graph, hetero_graph)
                losses.append(_compute_loss(loss_f, pred, homo_graph.ndata['label'], args.label))
            if not losses:
                continue  # nothing to average; avoids ZeroDivisionError
            group_loss = sum(losses) / len(losses)
            if training:
                group_loss.backward()
                optimizer.step()
            total_loss_.update(group_loss.item(), 1)
    return total_loss_.avg


def main():
    """Train NetlistGNN from command-line arguments.

    Steps:
    1. Parse args with ``Parser`` and seed numpy and torch.
    2. Open ``<result_path>/netlistgnn/<timestamp>/loss.log``.
    3. Load the train and test graph lists.
    4. Build the model. When ``args.pretrain`` is false, warm-start it from
       ``<result_path>/netlistgnn/checkpoint.pth``.
    5. Train for ``args.epochs`` epochs, evaluating on every epoch > 0 that is
       a multiple of 5.
    6. When ``args.pretrain`` is true and no checkpoint exists yet, save the
       final weights to that checkpoint path.

    Exits with status 1 when a warm start is requested but the checkpoint is
    missing. Raises ``ValueError`` for an unsupported ``args.label`` or an
    empty training list.
    """
    argp = Parser()
    args = argp.parser.parse_args()
    # Validate the target before any expensive dataset loading.
    _label_columns(args.label)

    seed = args.seed
    np.random.seed(seed)
    torch.manual_seed(seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    args.model = "netlistgnn"
    model_dir = os.path.join(args.result_path, args.model)
    checkpoint_path = os.path.join(model_dir, 'checkpoint.pth')
    if not args.pretrain and not os.path.exists(checkpoint_path):
        # A warm start needs an existing checkpoint; fail before loading data.
        print(f"checkpoint not found: {checkpoint_path}")
        sys.exit(1)

    cur_time = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime())
    result_dir = os.path.join(model_dir, cur_time)
    os.makedirs(result_dir, exist_ok=True)
    log = Logger()
    log.open(os.path.join(result_dir, 'loss.log'), mode="w")

    train_dataset = NetlistgnnDataset(args.data_root, args.train_list, args)
    train_list_tuple_graph = train_dataset.graphs
    test_dataset = NetlistgnnDataset(args.data_root, args.test_list, args)
    test_list_tuple_graph = test_dataset.graphs
    if len(train_list_tuple_graph) == 0 or len(train_list_tuple_graph[0]) == 0:
        raise ValueError(f"no training graphs loaded from {args.train_list!r}")

    # Take input widths from the features the model will actually receive, so
    # topo-only column selection and appended positions are accounted for.
    sample_homo, sample_hetero = train_list_tuple_graph[0][0]
    sample_node_feat, sample_net_feat = _build_inputs(args, sample_homo, sample_hetero)
    in_node_feats = sample_node_feat.shape[1]
    in_net_feats = sample_net_feat.shape[1]
    in_pin_feats = sample_hetero.edges['pinned'].data['he'].shape[1]
    output_dim = sample_homo.ndata['label'].shape[1]

    config = {
        'N_LAYER': args.layers,
        'NODE_FEATS': args.node_feats,
        'NET_FEATS': args.net_feats,
        'PIN_FEATS': args.pin_feats,
        'EDGE_FEATS': args.edge_feats,
    }
    model = NetlistGNN(
        in_node_feats=in_node_feats,
        in_net_feats=in_net_feats,
        in_pin_feats=in_pin_feats,
        output_dim=output_dim,
        in_edge_feats=1,
        n_target=output_dim,
        activation=args.outtype,
        config=config,
        recurrent=args.recurrent,
        topo_conv_type=args.topo_conv_type,
        geom_conv_type=args.geom_conv_type,
        agg_type=args.agg_type,
        cat_raw=args.cat_raw
    ).to(device)

    if not args.pretrain:
        print("loading model state dict")
        # map_location lets a GPU-saved checkpoint load on a CPU-only host.
        pretrain_dict = torch.load(checkpoint_path, map_location=device)
        model_dict = model.state_dict()
        # Keep only tensors whose name and shape match this model, so a
        # checkpoint from a differently-configured run still initialises the
        # compatible layers.
        pretrained_dict = {
            key: value for key, value in pretrain_dict.items()
            if key in model_dict and value.shape == model_dict[key].shape
        }
        skipped = len(pretrain_dict) - len(pretrained_dict)
        if skipped:
            print(f"skipped {skipped} checkpoint tensors absent from the model or with mismatched shapes")
        model.load_state_dict(pretrained_dict, strict=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # Multiply the learning rate by (1 - lr_decay) after every epoch.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=(1 - args.lr_decay))
    loss_f = nn.MSELoss()
    log.write(f' epochs = {args.epochs}     lr = {args.lr}    pretrain = {args.pretrain}  label=={args.label}')
    log.write('\n\nbegin_train: \n\n')

    for epoch in range(args.epochs):
        print(f'##### EPOCH {epoch} #####')
        print(f'\tLearning rate: {optimizer.state_dict()["param_groups"][0]["lr"]}')
        log.write(f"----------------epoch={epoch}----------------\n")

        if args.trans:
            # Transfer mode: only the net-readout part is switched to train
            # mode. NOTE(review): presumes net_readout_params holds modules
            # (they expose .train()); confirm in NetlistGNN.
            for p in model.net_readout_params:
                p.train()
        else:
            model.train()
        t1 = timer()
        train_loss = _run_epoch(model, train_list_tuple_graph, args, device, loss_f, optimizer)
        scheduler.step()
        log.write('%s %6.0f | %0.3f  | %s\n' % (
            "train", epoch, train_loss, time_to_str((timer() - t1), 'min')))

        if epoch > 0 and epoch % 5 == 0:
            model.eval()
            t1 = timer()
            test_loss = _run_epoch(model, test_list_tuple_graph, args, device, loss_f)
            log.write('%s %6.0f | %0.3f  | %s\n' % (
                "test", epoch, test_loss, time_to_str((timer() - t1), 'min')))

    # Only a pretraining run writes the checkpoint, and it never overwrites an
    # existing one.
    if args.pretrain and not os.path.exists(checkpoint_path):
        print("gengxin checkpoint")
        torch.save(model.state_dict(), checkpoint_path)
    print("complete")



if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()