"""Train Knowledge Graph embeddings for link prediction."""

import argparse
from distutils.log import error
import json
import logging
import os
import sys
from numpy import dtype

import torch
from torch.functional import split
import torch.optim
from torch.utils.tensorboard import SummaryWriter

import models
import optimizers.regularizers as regularizers
from datasets.kg_dataset import KGDataset
from models import all_models
from optimizers.kg_optimizer import KGOptimizer
from utils.train import get_savedir, avg_both, format_metrics, count_params

# A trace function is installed when the interpreter runs under a tracer
# (typically a debugger); in that case point the project's environment
# variables at paths relative to the current working directory.
isDebug = bool(sys.gettrace())
if isDebug:
    os.environ.update({'KGHOME': './', 'LOG_DIR': "logs", 'DATA_PATH': "data"})
# Command-line interface of the training script.
parser = argparse.ArgumentParser(description="Knowledge Graph Embedding")

# One entry per flag: (flag name, keyword arguments for ``add_argument``).
# Flags are registered in the listed order, which is also the order in which
# they appear in ``--help``.
_ARGUMENT_SPECS = [
    # --- data / model selection ---
    ("--dataset", dict(
        default="WN18RR",
        choices=["FB15K", "WN", "WN18RR", "FB237", "YAGO3-10", "YAGO3-10-DR"],
        help="Knowledge Graph dataset",
    )),
    ("--model", dict(
        default="RotE",
        choices=all_models,
        help="Knowledge Graph embedding model",
    )),
    # --- regularization / optimization ---
    ("--regularizer", dict(
        choices=["N3", "F2", "DURA_RESCAL", "DURA_W", "DURA_UniBi_2", "DURA_UniBi_3", "DURA_QuatE"],
        default="N3",
        help="Regularizer",
    )),
    ("--reg", dict(default=0, type=float, help="Regularization weight")),
    ("--optimizer", dict(
        choices=["Adagrad", "Adam", "SparseAdam", "Adamax", "AdamW"],
        default="Adagrad",
        help="Optimizer",
    )),
    # --- training schedule ---
    ("--max_epochs", dict(default=50, type=int, help="Maximum number of epochs to train for")),
    ("--patience", dict(default=10, type=int, help="Number of epochs before early stopping")),
    ("--valid", dict(default=3, type=float, help="Number of epochs before validation")),
    # --- model hyper-parameters ---
    ("--rank", dict(default=1000, type=int, help="Embedding dimension")),
    ("--batch_size", dict(default=1000, type=int, help="Batch size")),
    ("--neg_sample_size", dict(
        default=50,
        type=int,
        help="Negative sample size, -1 to not use negative sampling",
    )),
    ("--init_size", dict(
        default=1e-3,
        type=float,
        help=" >0:Initial embeddings' scale, -1: using xavier_",
    )),
    ("--learning_rate", dict(default=1e-1, type=float, help="Learning rate")),
    ("--gamma", dict(default=0, type=float, help="Margin for distance-based losses")),
    ("--dropout", dict(default=0, type=float, help="Dropout rate")),
    ("--bias", dict(
        default="constant",
        type=str,
        choices=["constant", "learn", "none"],
        help="Bias type (none for no bias)",
    )),
    ("--dtype", dict(
        default="double",
        type=str,
        choices=["single", "double"],
        help="Machine precision",
    )),
    # --- boolean switches ---
    ("--double_neg", dict(
        action="store_true",
        help="Whether to negative sample both head and tail entities",
    )),
    ("--debug", dict(action="store_true", help="Only use 1000 examples for debugging")),
    ("--multi_c", dict(action="store_true", help="Multiple curvatures per relation")),
    ("--eval_on_train", dict(action="store_true", help="Whether eval on train set")),
    ("--ce_weight", dict(action="store_true", help='Whether using ce weight')),
    ("--batched_eval", dict(
        action="store_true",
        help="Whether switch the eval model to False when evaluation, since when we do something to the tail entities, the original frame will just clash",
    )),
    ("--entity_norm", dict(action="store_true", help="whether regularize the entities")),
    ("--rel_norm", dict(action="store_true", help="whether regularize the relations")),
    # --- loss shaping / graph preprocessing ---
    ("--sta_scale", dict(type=float, default=1, help='the scale for the cos')),
    ("--label_smooth", dict(type=float, default=0, help='label smooth factor')),
    ("--add_self_loop", dict(action="store_true", help="whether to add self loop")),
]

for _flag, _options in _ARGUMENT_SPECS:
    parser.add_argument(_flag, **_options)

    

def _setup_logging(save_dir):
    """Send root-logger records to ``<save_dir>/train.log`` and to the console.

    Returns:
        The handlers this call attached to the root logger, so the caller can
        close them once the run is over.
    """
    root = logging.getLogger('')
    already_attached = list(root.handlers)

    # file logger (a no-op if the root logger already has handlers)
    logging.basicConfig(
        format="%(asctime)s %(levelname)-8s %(message)s",
        level=logging.INFO,
        datefmt="%Y-%m-%d %H:%M:%S",
        filename=os.path.join(save_dir, "train.log")
    )

    # stdout logger
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    console.setFormatter(logging.Formatter("%(asctime)s %(levelname)-8s %(message)s"))
    root.addHandler(console)

    return [h for h in root.handlers if h not in already_attached]


def _teardown_logging(owned_handlers):
    """Detach all root-logger handlers and close the ones this run created."""
    # Detaching everything lets the next call's ``basicConfig`` take effect.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
    # Closing releases the ``train.log`` file handle; handlers that were
    # installed by someone else are detached but left open.
    for handler in owned_handlers:
        handler.close()


def _evaluate(model, examples, filters, args):
    """Compute ranking metrics on ``examples`` and combine them with ``avg_both``."""
    return avg_both(*model.compute_metrics(
        examples, filters, batch_size=1000, batched_eval=args.batched_eval))


def _write_rank_metrics(writer, tag, metrics, step):
    """Write MRR and hits@{1,3,10} from ``metrics`` to TensorBoard under ``tag``."""
    hits = metrics['hits@[1,3,10]']
    writer.add_scalars(
        tag,
        {'MRR': metrics['MRR'], 'hits@1': hits[0], 'hits@3': hits[1], 'hits@10': hits[2]},
        step,
    )


def _run_training(args, save_dir, writer):
    """Run the optimisation loop and the final evaluation; return test metrics."""
    logging.info("Saving logs in: {}".format(save_dir))

    # logging config args
    for k, v in sorted(vars(args).items()):
        logging.info('{}={}'.format(k, v))

    # create dataset
    dataset_path = os.path.join(os.environ["DATA_PATH"], args.dataset)
    dataset = KGDataset(dataset_path, args.debug, args.add_self_loop)
    args.sizes = dataset.get_shape()  # [entity_nums, relation_nums, entity_nums]

    # load data
    logging.info("\t " + str(dataset.get_shape()))
    train_examples_eval = dataset.get_examples("train_no_re")  # no reciprocal, for evaluation on train set
    train_examples = dataset.get_examples("train")
    valid_examples = dataset.get_examples("valid")
    test_examples = dataset.get_examples("test")
    filters = dataset.get_filters()

    # save config (after ``args.sizes`` is set so it is part of the dump)
    with open(os.path.join(save_dir, "config.json"), "w") as fjson:
        json.dump(vars(args), fjson)

    # ce weight
    ce_weight = torch.Tensor(dataset.get_weight()).cuda() if args.ce_weight else None

    # create model
    model = getattr(models, args.model)(args)
    total = count_params(model)
    logging.info("Total number of parameters {}".format(total))
    device = "cuda"
    model.to(device)

    # get optimizer
    regularizer = getattr(regularizers, args.regularizer)(args.reg)
    optim_method = getattr(torch.optim, args.optimizer)(model.parameters(), lr=args.learning_rate)
    optimizer = KGOptimizer(
        model, regularizer, optim_method, args.batch_size, args.neg_sample_size,
        bool(args.double_neg), weight=ce_weight, label_smooth=args.label_smooth)

    checkpoint_path = os.path.join(save_dir, "model.pt")
    counter = 0        # consecutive validation rounds without MRR improvement
    best_mrr = None    # None means "no validation round has run yet"
    best_epoch = None
    logging.info("\t Start training")
    for step in range(args.max_epochs):

        # Train step
        model.train()
        train_loss = optimizer.epoch(train_examples)
        logging.info("\t Epoch {} | average train loss: {:.4f}".format(step, train_loss))

        # Valid step
        model.eval()
        valid_loss = optimizer.calculate_valid_loss(valid_examples)
        logging.info("\t Epoch {} | average valid loss: {:.4f}".format(step, valid_loss))

        # write losses into tensorboard
        writer.add_scalar('train_loss', train_loss, step)
        writer.add_scalar('valid_loss', valid_loss, step)

        # Ranking metrics are only computed every ``args.valid`` epochs.
        if (step + 1) % args.valid != 0:
            continue

        if args.eval_on_train:
            train_metrics = _evaluate(model, train_examples_eval, filters, args)
            logging.info(format_metrics(train_metrics, split='train'))
            _write_rank_metrics(writer, 'train_metrics', train_metrics, step)

        valid_metrics = _evaluate(model, valid_examples, filters, args)
        logging.info(format_metrics(valid_metrics, split="valid"))
        _write_rank_metrics(writer, 'valid_metrics', valid_metrics, step)

        valid_mrr = valid_metrics["MRR"]
        # Compare against None explicitly: an MRR of 0.0 is a valid best score.
        if best_mrr is None or valid_mrr > best_mrr:
            best_mrr = valid_mrr
            counter = 0
            best_epoch = step
            logging.info("\t Saving model at epoch {} in {}".format(step, save_dir))
            # Save from CPU, then move the model back for the next epoch.
            torch.save(model.cpu().state_dict(), checkpoint_path)
            model.cuda()
        else:
            counter += 1
            # ``patience`` counts validation rounds, not raw epochs.
            if counter == args.patience:
                logging.info("\t Early stopping")
                break

    logging.info("\t Optimization finished")
    if best_mrr is None:
        # No validation round ran, so keep the final weights as the checkpoint.
        torch.save(model.cpu().state_dict(), checkpoint_path)
    else:
        logging.info("\t Loading best model saved at epoch {}".format(best_epoch))
        model.load_state_dict(torch.load(checkpoint_path))
    model.cuda()
    model.eval()

    # Validation metrics
    valid_metrics = _evaluate(model, valid_examples, filters, args)
    logging.info(format_metrics(valid_metrics, split="valid"))

    # Test metrics
    test_metrics = _evaluate(model, test_examples, filters, args)
    logging.info(format_metrics(test_metrics, split="test"))

    return test_metrics


def train(args):
    """Train a knowledge-graph embedding model and evaluate it on the test split.

    Logs to ``train.log`` and TensorBoard inside the directory returned by
    ``get_savedir``, writes the run configuration to ``config.json`` and keeps
    the weights with the best validation MRR in ``model.pt``. The dataset is
    read from ``$DATA_PATH/<args.dataset>`` and the model runs on CUDA.

    Args:
        args: Parsed command-line namespace (see ``parser``). ``args.sizes`` is
            set as a side effect from the dataset shape.

    Returns:
        The test-split metrics produced by ``avg_both`` for the selected
        checkpoint. The value is returned unconditionally, including when the
        module is run as a script.
    """
    save_dir = get_savedir(args.model, args.dataset)
    log_handlers = _setup_logging(save_dir)
    writer = None
    try:
        # tensorboard event files live next to the text log
        writer = SummaryWriter(save_dir, flush_secs=5)
        return _run_training(args, save_dir, writer)
    finally:
        # Release resources even if training raised, so that a later call to
        # ``train`` in the same process starts from a clean logging setup.
        if writer is not None:
            writer.close()
        _teardown_logging(log_handlers)


if __name__ == "__main__":
    # Script entry point: parse the command line and start a training run.
    cli_args = parser.parse_args()
    train(cli_args)


