"""
Get explanations:
Usage:
  get_explanations.py <model_path> <explainer> --rate <rate> [-s <root>]
  get_explanations.py -h | --help
  get_explanations.py --version

Options:
  -r --rate <rate>       Sets positive base rate for calibrating model if not calibrated.
  -s --save <root>       Specify whether and where to save the models and results.
  -h --help              Show this screen.
  --version              Show version.
"""
# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
from docopt import docopt


import os
import time
import numpy as np
import pandas as pd
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from ampligraph import latent_features
from ampligraph.datasets import load_fb15k_237, load_wn18rr, load_yago3_10, load_codex
from ampligraph.evaluation import evaluate_explainer, mrr_score, hits_at_n_score
from ampligraph.explanations import ExamplE
from ampligraph.utils import save_model, restore_model
from sklearn.calibration import calibration_curve
from scipy.special import expit
import matplotlib.pyplot as plt
import matplotlib
import pickle as pkl
from npencoder import NpEncoder
import json

def load_model_data(model_path, root, positive_rate=None, reliability_negatives_n=100):
    """Load the dataset matching ``model_path`` and the (calibrated) model.

    The dataset is selected by a substring of ``model_path`` ("FB15K-237",
    "WN18RR", "YAGO" or "CODEX"). If the restored model is not calibrated,
    a previously saved calibrated copy is reused when it exists; otherwise
    the model is calibrated on the validation split, saved, and a
    reliability diagram comparing uncalibrated and calibrated outputs is
    produced.

    Parameters
    ----------
    model_path: str
        path of the saved model to restore; must contain the dataset name.
    root: str or None
        directory where the calibrated model and the reliability diagram are
        saved. When None, the calibrated model is stored next to the model
        (``<model_path>_calibrated``) and the diagram is shown, not saved.
    positive_rate: float, optional
        positive base rate passed to ``model.calibrate``.
    reliability_negatives_n: int
        number of validation triples to corrupt, and maximum number of
        entities used to corrupt them, for the reliability diagram.

    Returns
    -------
    model: ampligraph.EmbeddingModel
        calibrated model.
    X: dict
        dataset used for training the model (train, test, valid).

    Raises
    ------
    ValueError
        if ``model_path`` does not name a known dataset.
    """
    if "FB15K-237" in model_path:
        X = load_fb15k_237()
    elif "WN18RR" in model_path:
        X = load_wn18rr()
    elif "YAGO" in model_path:
        X = load_yago3_10()
    elif "CODEX" in model_path:
        X = load_codex()
    else:
        raise ValueError(f"No such dataset - {model_path}")

    # normpath strips a trailing separator, which would otherwise give an
    # empty model name.
    normalized_path = os.path.normpath(model_path)
    model_name = os.path.basename(normalized_path)
    print_dataset_stats(X)
    model = restore_model(model_path)
    if model.is_calibrated:
        return model, X

    # Check, save and restore the calibrated copy at one single location so a
    # calibrated model saved by an earlier run is actually reused.
    save_dir = root if root is not None else os.path.dirname(normalized_path)
    calibrated_path = os.path.join(save_dir, f"{model_name}_calibrated")
    if os.path.exists(calibrated_path):
        return restore_model(calibrated_path), X

    model.calibrate(X['valid'], positive_base_rate=positive_rate)
    save_model(model, calibrated_path)

    valid = X['valid']
    head = valid[:reliability_negatives_n]
    # Corrupt only with validation entities that do not occur in the triples
    # being corrupted. Sorting makes the choice reproducible across runs
    # (set iteration order of strings depends on hash randomization).
    all_ents = set(valid[:, 0]) | set(valid[:, 2])
    used_ents = set(head[:, 0]) | set(head[:, 2])
    entities_for_corruption = sorted(all_ents - used_ents)[:reliability_negatives_n]

    corruptions = generate_corruptions_for_eval(head, entities_for_corruption, corrupt_side='s+o')
    print(f"Generated {len(corruptions)} corruptions for the reliability diagram")
    eval_triples = np.concatenate([valid, np.asarray(corruptions)], 0)
    # All validation triples are positives, all corruptions negatives.
    labels = [1] * len(valid) + [0] * len(corruptions)

    # Scores of the uncalibrated model, to contrast with calibrated probabilities.
    model_not_cal = restore_model(model_path)
    scores = model_not_cal.predict(eval_triples)
    probas = model.predict_proba(eval_triples)
    plot_reliability_diagram(labels, scores, probas, positive_rate,
                             save=root, name=model_name, setname='valid')
    return model, X


def generate_corruptions_for_eval(X, entities_for_corruption, corrupt_side='s+o'):
    """Generate corruptions for evaluation.

    Create corruptions (subject and/or object) of the given triples, in
    compliance with the local closed world assumption (LCWA), as described in
    :cite:`nickel2016review`. Every triple is corrupted with every entity in
    ``entities_for_corruption``.

    Parameters
    ----------
    X : array-like, shape [n, 3] or [3]
        Positive triples (subject, predicate, object) to corrupt. A single
        triple may also be passed as a 1-D array of length 3.
    entities_for_corruption : array-like, shape [k]
        Entity IDs used to replace the subject and/or the object.
    corrupt_side: string
        Specifies which side of the triple to corrupt:

        - 's': corrupt only subject.
        - 'o': corrupt only object
        - 's+o': corrupt both subject and object

    Returns
    -------
    out : ndarray, shape [n * k, 3] (shape [2 * n * k, 3] for 's+o')
        For each triple in order, its corruptions with every entity in order.
        With 's+o' all object corruptions come first, followed by all subject
        corruptions in the same order.

    Raises
    ------
    ValueError
        If ``corrupt_side`` is not one of 's', 'o', 's+o'.
    """
    if corrupt_side not in ['s+o', 's', 'o']:
        msg = 'Invalid argument value for corruption side passed for evaluation'
        raise ValueError(msg)

    X = np.atleast_2d(X)
    n_ents = np.shape(entities_for_corruption)[0]

    # Each triple's fields are repeated n_ents times while the entity list is
    # tiled once per triple, so output row j pairs triple j // n_ents with
    # entity j % n_ents.
    repeated_subjs = np.repeat(X[:, 0], n_ents)
    repeated_relns = np.repeat(X[:, 1], n_ents)
    repeated_objs = np.repeat(X[:, 2], n_ents)
    rep_ent = np.tile(entities_for_corruption, X.shape[0])

    parts = []
    if corrupt_side in ('s+o', 'o'):
        parts.append(np.stack([repeated_subjs, repeated_relns, rep_ent], 1))
    if corrupt_side in ('s+o', 's'):
        parts.append(np.stack([rep_ent, repeated_relns, repeated_objs], 1))
    return np.concatenate(parts, 0)


def plot_reliability_diagram(labels, scores, probas, positive_rate, save="./", name='experiment', setname='valid'):
    """Plot a reliability diagram for uncalibrated vs. calibrated predictions.

    The top panel plots, over 10 quantile bins, the fraction of positive
    labels against the mean predicted value, next to the identity line of a
    perfectly calibrated model. The panel below holds step histograms of the
    predicted values.

    Parameters
    ----------
    labels: array-like
        binary ground-truth labels.
    scores: array-like
        raw scores of the uncalibrated model; mapped through the logistic
        sigmoid (``expit``) before plotting.
    probas: array-like
        probabilities of the calibrated model.
    positive_rate: float
        not used by this function.
    save: str or None
        directory in which to write ``reliability_diagram_<name>_<setname>.png``;
        when None the figure is shown interactively instead.
    name, setname: str
        parts of the saved file name.
    """
    plt.rcParams.update({'font.size': 24, 'axes.titlesize': 30})

    plt.figure(figsize=(18, 18))
    curve_ax = plt.subplot2grid((3, 1), (0, 0), rowspan=1)
    hist_ax = plt.subplot2grid((3, 1), (1, 0))

    curve_ax.plot([0, 1], [0, 1], "k:", lw=3, label="Perfectly calibrated")

    series = (("Uncalibrated", expit(scores)), ("Calibrated", probas))
    for series_name, predicted in series:
        frac_pos, mean_pred = calibration_curve(labels, predicted, n_bins=10, strategy="quantile")
        curve_ax.plot(mean_pred, frac_pos, "s-", lw=3, label=series_name)
    for series_name, predicted in series:
        hist_ax.hist(predicted, range=(0, 1), bins=10, histtype="step", lw=3, label=series_name)

    curve_ax.set_ylabel("Fraction of positives")
    curve_ax.set_ylim([-0.05, 1.05])
    curve_ax.legend(loc="lower right")
    curve_ax.set_title('Reliability diagram')

    hist_ax.set_xlabel("Mean predicted value")
    hist_ax.set_ylabel("Count")
    hist_ax.legend(loc="upper center", ncol=2)

    plt.tight_layout()
    if save is None:
        plt.show()
        return
    plt.savefig(os.path.join(save, 'reliability_diagram_{}_{}.png'.format(name, setname)))
    plt.clf()
    plt.cla()
    plt.close()


def print_dataset_stats(X):
    """Print triple counts per split plus entity/predicate counts of the train split.

    Parameters
    ----------
    X: dict
        dataset with 'train', 'test' and 'valid' arrays whose rows are
        (subject, predicate, object) triples.
    """
    for split_name in ('train', 'test', 'valid'):
        print(f"#{split_name} triples: {len(X[split_name])}")
    train = X['train']
    # Entities are counted from subjects and objects of the train split only.
    entities = set(train[:, 0]).union(train[:, 2])
    predicates = set(train[:, 1])
    print(f"#entities: {len(entities)}")
    print(f"#predicates: {len(predicates)}")


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``."""
    # NOTE(review): pickle can execute arbitrary code on load; only load cache
    # files produced by this script, never files from untrusted sources.
    with open(path, 'rb') as f:
        return pkl.load(f)


def _dump_pickle(obj, path):
    """Serialize ``obj`` to the file at ``path`` with pickle."""
    with open(path, 'wb') as f:
        pkl.dump(obj, f)


def _sample_random_explanations(explanations, X, test_triples):
    """Build a random counterpart for each explanation in ``explanations``.

    Entry i is drawn without replacement from the training triples that share
    the predicate of ``test_triples[i]``. It has as many triples as
    ``explanations[i]``, capped at the size of that pool.
    """
    predicates = test_triples[:, 1]
    random_explanations = []
    for i, expl in enumerate(explanations):
        triples_pool = X['train'][X['train'][:, 1] == predicates[i]]
        print(f'triples pool: {len(triples_pool)}, explanation size: {len(expl)}')
        # choice(replace=False) raises if asked for more items than the pool holds.
        size = min(len(expl), len(triples_pool))
        rand_idx = np.random.choice(range(len(triples_pool)), size, replace=False)
        random_explanations.append(triples_pool[rand_idx])
    try:
        return np.asarray(random_explanations)
    except ValueError:
        # Explanations of different lengths are ragged; NumPy >= 1.24 refuses
        # to build such arrays implicitly, so store them in a 1-D object array.
        ragged = np.empty(len(random_explanations), dtype=object)
        for i, expl in enumerate(random_explanations):
            ragged[i] = expl
        return ragged


def get_explanations(model, X, model_name, root, explainer_name='example', n=-1):
    """Compute, or load from a pickle cache, explanations for the test triples.

    Parameters
    ----------
    model: ampligraph.EmbeddingModel
        model whose predictions are explained.
    X: dict
        dataset with 'train' and 'test' arrays of triples.
    model_name: str
        used to build the cache file name.
    root: str or None
        directory of the cache file
        ``explanations_<explainer_name>_<model_name>.pkl``. The current
        directory is used when None.
    explainer_name: str
        'example' runs ExamplE. 'random' picks, for each test triple, as many
        training triples with the same predicate as its ExamplE explanation
        has. Cached ExamplE explanations are reused when present.
    n: int
        number of test triples to explain. A negative value (the default) or
        None means all of them. Note that the cache file name does not
        depend on ``n``.

    Returns
    -------
    explanations:
        the explanations (loaded from cache or freshly computed).
    elapsed: float
        CPU seconds spent computing them, or -1 when loaded from cache.

    Raises
    ------
    ValueError
        for an unknown ``explainer_name``.
    """
    elapsed = -1
    cache_dir = '.' if root is None else root
    expls_path = os.path.join(cache_dir, f"explanations_{explainer_name}_{model_name}.pkl")
    # A slice end of -1 would silently drop the last test triple.
    limit = None if n is None or n < 0 else n
    test_triples = X['test'][:limit]

    if explainer_name == 'example':
        if os.path.exists(expls_path):
            explanations = _load_pickle(expls_path)
        else:
            start = time.process_time()
            explanations = ExamplE(X, model).batch_explain(test_triples, score=True)
            elapsed = time.process_time() - start
            _dump_pickle(explanations, expls_path)

    elif explainer_name == 'random':
        # Timing includes producing the ExamplE explanations the random ones mimic.
        start = time.process_time()
        if os.path.exists(expls_path):
            explanations = _load_pickle(expls_path)
        else:
            # Build the ExamplE cache path explicitly. A string replace on the
            # full path would also alter root/model_name containing "random".
            example_path = os.path.join(cache_dir, f"explanations_example_{model_name}.pkl")
            if os.path.exists(example_path):
                base_explanations = _load_pickle(example_path)
            else:
                base_explanations = ExamplE(X, model).batch_explain(test_triples, score=True)
            explanations = _sample_random_explanations(base_explanations, X, test_triples)
            elapsed = time.process_time() - start
            _dump_pickle(explanations, expls_path)
    else:
        raise ValueError(f"No such explainer found: {explainer_name}")

    print(explanations[:5])
    return explanations, elapsed


if __name__ == "__main__":
    arguments = docopt(__doc__, version='Get explanations')

    root = arguments['--save']
    model_path = arguments['<model_path>']
    positive_rate = float(arguments['--rate'])
    explainer_name = arguments['<explainer>']
    if root is not None:
        # makedirs also creates missing parent directories, unlike os.mkdir.
        os.makedirs(root, exist_ok=True)
    print(arguments)
    model, X = load_model_data(model_path, root, positive_rate)
    # normpath strips a trailing separator that would yield an empty name.
    model_name = os.path.basename(os.path.normpath(model_path))
    explanations, elapsed = get_explanations(model, X, model_name, root, explainer_name)
    time_dict = {model_name: {'explainer': explainer_name,
                              'total_time_get_explanation': elapsed,
                              'time-per-triple': elapsed / len(X['test'])}}
    # One JSON object per line, so repeated runs append a parseable JSON Lines file.
    with open('results_times.json', 'a') as f:
        f.write(json.dumps(time_dict, cls=NpEncoder) + '\n')
    print(f"Total CPU time to get explanations: {elapsed}")
