import os
import time
import torch
import argparse
import pickle
import warnings
import numpy as np

warnings.filterwarnings("ignore")

from utils import Logger
from knowledge_generation import KnowledgeGenerator
from torch.optim import Adam
from configs.data_config import add_data_config
from configs.model_config import add_model_config
from configs.train_config import add_training_config
from datasets.load_data import load_directed_graph
from utils import seed_everything

def kd_node_cls_train(model, data_x, labels, train_idx, optimizer, node_size,
                      report_unchanged_params=True):
    """Run a single optimisation step of ``model`` and report training metrics.

    The model is called as ``model(data_x, labels, train_idx, node_size)`` and is
    expected to return ``(loss, output)``. Accuracy is measured on the rows of
    ``output`` selected by ``train_idx``.

    Args:
        model: module exposing ``train()``, ``named_parameters()`` and the call
            signature above.
        data_x: node feature input forwarded to the model.
        labels: label tensor; indexed with ``train_idx`` for accuracy.
        train_idx: indices of the training nodes.
        optimizer: optimizer over ``model``'s parameters.
        node_size: forwarded unchanged to the model.
        report_unchanged_params: when True (the default, matching the previous
            behaviour), print the name of every parameter whose values are
            bit-for-bit identical before and after the optimizer step. This is
            a debugging aid for spotting parameters that receive no update.

    Returns:
        ``(loss_train.item(), acc_train)`` as Python floats.
    """
    model.train()

    snapshot = None
    if report_unchanged_params:
        # detach() so the copies are not recorded in the autograd graph.
        snapshot = [(name, param.detach().clone())
                    for name, param in model.named_parameters()]

    optimizer.zero_grad()
    loss_train, train_output = model(data_x, labels, train_idx, node_size)
    acc_train = accuracy(train_output[train_idx], labels[train_idx])
    loss_train.backward()
    optimizer.step()

    if report_unchanged_params:
        # Compare against the live parameters directly; no second clone is needed.
        for (name, before), (_, after) in zip(snapshot, model.named_parameters()):
            if torch.equal(before, after.detach()):
                print(name)

    return loss_train.item(), acc_train

def accuracy(output, labels):
    """Return the fraction of rows whose arg-max column matches ``labels``.

    ``output`` is indexed along dim 1 for the prediction, so each row holds one
    score per class. The predicted indices are cast to ``labels``' dtype before
    comparison. The result is a Python float computed in double precision.
    """
    predicted = output.max(1).indices.type_as(labels)
    num_correct = (predicted == labels).double().sum()
    total = len(labels)
    return (num_correct / total).item()


if __name__ == "__main__":
    # Script entry point: load the pre-built tree and the graph, then train the
    # KnowledgeGenerator `normalize_times` times, logging per-epoch metrics and
    # summary statistics of the best validation accuracy.
    parser = argparse.ArgumentParser(add_help=False)
    add_data_config(parser)
    add_model_config(parser)
    add_training_config(parser)

    args = parser.parse_args()

    seed_everything(args.seed)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Load the pre-built tree for this dataset / depth.
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated tree files.
    PWD = os.path.dirname(os.path.realpath(__file__))
    t_path = os.path.join(PWD, 'trees', '%s_%s.pickle' % (args.dataset, args.tree_depth))
    with open(t_path, 'rb') as fp:
        tree_data = pickle.load(fp)
    # Positional unpacking relies on the pickled dict's insertion order.
    tree_node_size, interLayer_edgeMat, layer_edgeMat, _ = tree_data.values()
    tree_height = len(tree_node_size)

    dataset = load_directed_graph(args, name=args.data_name, root=args.data_root, k=args.data_dimension_k,
                                  node_split=args.data_node_split, edge_split=args.data_edge_split,
                                  node_split_id=args.data_node_split_id, edge_split_id=args.data_edge_split_id)
    data = dataset
    num_features = data.num_features
    num_classes = data.num_classes
    labels = data.y.to(device)
    data_x = data.x.to(device)
    train_idx = data.train_idx.to(device)
    val_idx = data.val_idx.to(device)
    test_idx = data.test_idx.to(device)
    edge_index = data.edge_index.to(device)

    log_dir = os.path.join("log", "kd", args.data_name, args.data_edge_split)
    # NOTE(review): int(time.time()) % 17 gives only 17 distinct file names, so
    # separate runs can end up sharing the same log file.
    logger_time = "log " + str(int(time.time()) % 17)
    logger_name = os.path.join(log_dir, logger_time + ".log")
    logger = Logger(logger_name)
    logger.info(f"datasets: {args.data_name}")
    logger.info(f"num_epochs: {args.num_epochs}, lr: {args.lr}, weight_decay: {args.weight_decay}")
    logger.info(f"dataset.x.shape: {dataset.x.shape}")
    # Tensor.max() reduces natively; builtin max() iterated the tensor element by element.
    label_info = dataset.y.max() + 1 if dataset.y is not None else -1
    logger.info(f"max(dataset.y)+1: {label_info}")

    normalize_times = args.normalize_times
    normalize_record = {"val_acc": []}
    epochs = args.num_epochs

    total_epochs_time = []
    total_time = []
    for run in range(normalize_times):
        begin_t = time.time()
        if run == 0:
            normalize_times_st = begin_t

        # Fresh model + optimizer for every run. The model is moved to the
        # device *before* the optimizer is built, as torch.optim recommends.
        model = KnowledgeGenerator(args, num_classes=num_classes, num_features=num_features,
                                   node_size=tree_node_size, device=device, logger=logger,
                                   interLayer_edgeMat=interLayer_edgeMat, layer_edgeMat=layer_edgeMat)
        model = model.to(device)
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        t_total = time.time()
        best_val = 0.
        epochs_time = []
        for epoch in range(epochs):
            t = time.time()

            loss_train, acc_train = kd_node_cls_train(model, data_x, labels, train_idx, optimizer,
                                                      tree_node_size)
            acc_val = model.evaluate(val_idx, labels)
            epoch_time = time.time() - t
            logger.info("Epoch: {:03d}, loss_train: {:.4f}, acc_train: {:.4f}, acc_val: {:.4f}, "
                        "time: {:.4f}s\n".format(epoch + 1, loss_train, acc_train, acc_val, epoch_time))
            epochs_time.append(epoch_time)
            # Best validation accuracy over all epochs of this run.
            if acc_val > best_val:
                best_val = acc_val

        if normalize_times == 1:
            logger.info("Optimization Finished!")
            logger.info("Total training time is: {:.4f}s".format(time.time() - t_total))
            logger.info(f'Best val: {best_val:.4f}')
        normalize_record["val_acc"].append(best_val)

        total_epochs_time += epochs_time
        total_time.append(time.time() - begin_t)

    if normalize_times > 1:
        logger.info("Optimization Finished!")
        logger.info("Total training time is: {:.4f}s".format(time.time() - normalize_times_st))
        logger.info("Mean Val ± Std Val: {}±{}".format(
            round(np.mean(normalize_record["val_acc"]), 4),
            round(np.std(normalize_record["val_acc"], ddof=1), 4)))
        logger.info("Mean Epoch ± Std Epoch: {:.4f}s±{:.4f}s, Mean Total ± Std Total: {:.4f}s±{:.4f}s".format(
            np.mean(total_epochs_time), np.std(total_epochs_time),
            np.mean(total_time), np.std(total_time)))
        


