"""
Measure probability difference:
Usage:
  measure_diff.py <model_path> <dataset> <triple_selection> -r <rate> --gpu <gpu> [-s <root>]
  measure_diff.py -h | --help
  measure_diff.py --version

Options:
  -s --save <root>       Specify whether and where to save the models and results.
  -r --rate <rate>       Specify positive base rate for calibrating model.
  -h --help              Show this screen.
  --version              Show version.
"""
# Copyright 2022 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
from docopt import docopt
import os
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
import ampligraph
from ampligraph import latent_features
from ampligraph.datasets import load_fb15k_237, load_wn18rr, load_yago3_10
from ampligraph.evaluation import evaluate_explainer, mrr_score, hits_at_n_score
from ampligraph.explanations import ExamplE
from ampligraph.utils import save_model, restore_model
from sklearn.calibration import calibration_curve
from scipy.special import expit
import matplotlib.pyplot as plt
import matplotlib
import pickle as pkl
from ampligraph.evaluation import remove_triples, get_calibration_set 
from ampligraph.compat import BACK_COMPAT_MODELS

from get_explanations import load_model_data
from npencoder import NpEncoder

from ampligraph.evaluation import mrr_score
from ampligraph.utils import restore_model
from ampligraph.datasets import load_wn18rr
import numpy as np
import pandas as pd
from ampligraph.evaluation import mrr_score, hits_at_n_score, mr_score

def _measure_proba_diff_from_paths(model_path, org_model_path, explanation_path, predictions_path, X, root, positive_rate, gpu, triple_selection):
    """Load models, predictions and explanations from disk and measure one target triple.

    Parameters
    ----------
    model_path : str
        Path of the retrained model, loaded with ``restore_model``.
    org_model_path : str
        Path of the original model, loaded with ``restore_model``.
    explanation_path : str
        Path of the pickled explanations.
    predictions_path : str
        Path of the predictions CSV. The columns ``s``, ``o`` and ``index``
        are read here.
    X : dict
        Dataset splits; ``X['valid']`` is used to calibrate the retrained
        model when it is not calibrated yet.
    root, gpu
        Not used by this function; kept so the signature stays unchanged.
    positive_rate : float
        Positive base rate handed to ``model.calibrate``.
    triple_selection : str, int or None
        ``'highest'`` selects the first row of the predictions,
        ``'highest_nc'`` (also the default for ``None``) the first row whose
        subject differs from its object. An integer, or a string holding an
        integer (as delivered by the command line), is used directly.

    Returns
    -------
    dict
        The measures produced by ``_measure_proba_diff``.

    Raises
    ------
    ValueError
        If ``triple_selection`` is not a known strategy or an integer, or if
        ``'highest_nc'`` is requested but every row is circular.
    """
    # Resolve the target first: it only needs the CSV, so a bad selection
    # fails before the (much slower) model loading and calibration.
    predictions = pd.read_csv(predictions_path)

    if triple_selection is None:
        print("Index of target triple is not specified, assuming that it is for highest scoring triple that is not circular")
        triple_selection = 'highest_nc'
    if isinstance(triple_selection, str):
        # Command-line values always arrive as strings; accept "3" as 3.
        try:
            triple_selection = int(triple_selection)
        except ValueError:
            pass  # a named strategy, handled below

    # NOTE(review): the value taken from the 'index' column is later used as a
    # positional row (df.iloc[index]) by _measure_proba_diff. That is only the
    # same row if the CSV's 'index' column matches row positions - confirm.
    if isinstance(triple_selection, (int, np.integer)):
        index = triple_selection
    elif triple_selection == 'highest':
        # Presumably the CSV is sorted by score so the first row is the best one.
        index = predictions['index'].iloc[0]
    elif triple_selection == 'highest_nc':
        non_circular = predictions[predictions['s'] != predictions['o']]
        if non_circular.empty:
            raise ValueError(f"No non-circular triple found in {predictions_path}")
        index = non_circular['index'].iloc[0]
    else:
        raise ValueError(f"Not such selection for triple_selection, {triple_selection}")

    model_org = restore_model(org_model_path)
    model = restore_model(model_path)
    if not model.is_calibrated:
        model.calibrate(X['valid'], positive_base_rate=positive_rate)

    # SECURITY: unpickling executes arbitrary code; only load explanation
    # files that were produced by a trusted run.
    with open(explanation_path, 'rb') as f:
        explanations = pkl.load(f)

    return _measure_proba_diff(model, model_org, explanations, predictions, X, index)


def _measure_proba_diff(model, model_org, explanations, df, X, index=None, n_hit=1):
    """Compare the original and the retrained model on a single target triple.

    Parameters
    ----------
    model
        Retrained model; ``predict_proba`` and ``evaluate`` are called on it.
    model_org
        Original model; ``predict_proba`` and ``evaluate`` are called on it.
    explanations
        Container indexed with ``index``; each entry is iterated and its items
        are printed as ``<e[0], e[1], e[2]>`` with ``e[3]`` as the distance.
    df : pandas.DataFrame
        Predictions. The columns ``s``, ``o``, ``pred`` and ``index`` are read.
    X : dict
        Dataset splits; ``X['test']`` is indexed with the row's ``index`` value.
    index : int, optional
        Positional row of ``df`` holding the target triple. When omitted, the
        ``index`` value of the first row whose subject differs from its object
        is used.
    n_hit : int, optional
        ``n`` passed to ``hits_at_n_score``.

    Returns
    -------
    dict
        MRR, MR and hits@n for both models with their differences
        (original minus retrained), the probabilities before and after
        retraining, the target triple and its explanation.

    Raises
    ------
    ValueError
        If ``index`` is omitted and every row of ``df`` is circular.
    """
    if index is None:
        print("Index of target triple is not specified, assuming that it is for highest scoring triple that is not circular")
        non_circular = df[df['s'] != df['o']]
        if non_circular.empty:
            raise ValueError("No non-circular triple found in the predictions")
        # A single scalar is required: the whole column cannot be used as a row selector.
        index = non_circular['index'].iloc[0]

    target = df.iloc[index]
    print(f"Target Triple: {target}")
    print(f"Probability: {target['pred']}")
    ind = target['index']
    tt = np.asarray([X['test'][ind]])

    pre_proba = model_org.predict_proba(tt)[0]
    print(pre_proba)
    print(target, tt)
    post_proba = model.predict_proba(tt)[0]

    explanation = explanations[index]
    print(f"Explanation: \nNumber of influential examples: {len(explanation)}")
    for i, e in enumerate(explanation):
        print(f"{i}: <{e[0]}, {e[1]}, {e[2]}>, distance: {e[3]:.4}")
    print(f"Probability after retraining: {post_proba}")

    # Rank the target once per model and derive every metric from those ranks.
    ranks = model.evaluate(tt)
    ranks_org = model_org.evaluate(tt)
    mrr = mrr_score(ranks)
    mrr_org = mrr_score(ranks_org)
    mr = mr_score(ranks)
    mr_org = mr_score(ranks_org)
    hits = hits_at_n_score(ranks, n_hit)
    hits_org = hits_at_n_score(ranks_org, n_hit)

    # NOTE: 'hit{n}' (no "s") next to 'hits{n}_org' is kept as-is because these
    # keys end up in the results file that other tooling may already read.
    measures = {
        'mrr': mrr, 'mrr_org': mrr_org, 'mrr_diff': mrr_org - mrr,
        'mr_score': mr, 'mr_org': mr_org, 'mr_diff': mr_org - mr,
        f'hit{n_hit}': hits, f'hits{n_hit}_org': hits_org, f'hits{n_hit}_diff': hits_org - hits,
        'post_proba': post_proba, 'pre_proba': pre_proba, 'proba_diff': pre_proba - post_proba,
        's': tt[0][0], 'p': tt[0][1], 'o': tt[0][2], 'index': ind,
        'explanation': explanation,
    }
    print(measures)
    return measures


def measure_proba_diff(model_path, X, root, positive_rate, gpu, triple_selection):
    """Derive the companion file paths from ``model_path`` and measure the experiment.

    The original model path is ``model_path`` cut after its first
    ``'calibrated'``. The predictions and explanations files are looked up
    next to it as ``predictions_<name>`` and ``explanations_<name>.pkl``
    (the latter with ``'_calibrated'`` removed from the name).

    Parameters
    ----------
    model_path : str
        Path of the retrained model.
    X : dict
        Dataset splits, forwarded unchanged.
    root, positive_rate, gpu, triple_selection
        Forwarded unchanged to ``_measure_proba_diff_from_paths``.

    Returns
    -------
    dict
        The measures returned by ``_measure_proba_diff_from_paths``.
    """
    org_model_path = f"{model_path.split('calibrated')[0]}calibrated"
    # `sep` is empty when the path has no directory part, so the derived files
    # stay relative instead of being looked up in the filesystem root.
    directory, sep, org_name = org_model_path.rpartition('/')
    # What follows the original model's path is the scenario suffix (e.g. '_1_rev').
    scenario = model_path[len(org_model_path):]
    print(f"Evaluating experiment: {org_name.split('calibrated')[0]} with {scenario}")
    print("Number at the end '1_rev' means scenario 2 with 1 explanation triple left, _rev means scenario 2 if present if not scenario 1.")
    predictions_path = f"{directory}{sep}predictions_{org_name}"
    explanation_path = f"{directory}{sep}explanations_{org_name.replace('_calibrated', '')}.pkl"
    return _measure_proba_diff_from_paths(model_path, org_model_path, explanation_path, predictions_path, X, root, positive_rate, gpu, triple_selection)


if __name__ == "__main__":
    # Command-line entry point: parse the arguments, load the dataset, run the
    # measurement once per model path and cache the results in a JSON file.
    arguments = docopt(__doc__, version='Measure probability difference per TT')
    root = arguments['--save']
    model_path = arguments['<model_path>']
    # Parsed and passed along, but nothing in this file uses it.
    gpu = arguments['<gpu>']
    dataset = arguments['<dataset>']
    triple_selection = arguments['<triple_selection>']
    positive_rate = float(arguments['--rate'])
    if root is not None:
        # Also creates missing parent directories and tolerates an existing one.
        os.makedirs(root, exist_ok=True)
    print(arguments)

    dataset_map = { "WN18": "load_wn18",
                    "FB15K": "load_fb15k",
                    "FB15K-237": "load_fb15k_237",
                    "WN18RR": "load_wn18rr",
                    "YAGO310": "load_yago3_10"}

    loader_name = dataset_map.get(dataset.upper())
    if loader_name is None:
        raise SystemExit(f"Unknown dataset '{dataset}', expected one of: {', '.join(dataset_map)}")
    X = getattr(ampligraph.datasets, loader_name)()

    # Without --save there is nowhere to cache results, so nothing is read or written.
    result_file = None if root is None else f"{root}/measures_proba_diff_experiments.json"
    experiments = {}
    if result_file is not None and os.path.exists(result_file):
        with open(result_file, 'r') as f:
            experiments = json.load(f)

    if model_path in experiments:
        print(f"Experiment already done, see file: {result_file}")
    else:
        measures = measure_proba_diff(model_path, X, root, positive_rate, gpu, triple_selection)
        experiment = {'name': model_path, 'triple_selection':triple_selection, 'dataset':dataset, 'positive_rate':positive_rate, 'metrics': measures}
        experiments[model_path] = experiment
        if result_file is None:
            print("No --save directory given, results are not written to disk.")
        else:
            # Serialise before opening the file so a failure cannot truncate
            # the results of earlier experiments.
            payload = json.dumps(experiments, cls=NpEncoder)
            with open(result_file, 'w') as f:
                f.write(payload)

