"""
Retraining experiments:
Usage:
  retraining_experiment.py <model_path> <explanations_path> <experiment> <n_explanations> -r <rate> --gpu <gpu> [--monitfreq <frequency>] [--rev] [-s <root>]
  retraining_experiment.py -h | --help
  retraining_experiment.py --version

Options:
  --rev                  Whether to do a reversed experiment - remove everything except from the explanation per predicate.
  -s --save <root>       Specify whether and where to save the models and results.
  -r --rate <rate>       Specify positive base rate for calibrating model.
  -h --help              Show this screen.
  --version              Show version.
"""
# Copyright 2019-2021 The AmpliGraph Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
from docopt import docopt
import os
import tensorflow as tf
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
import ampligraph
from ampligraph import latent_features
from ampligraph.datasets import load_fb15k_237, load_wn18rr, load_yago3_10
from ampligraph.evaluation import evaluate_explainer, mrr_score, hits_at_n_score
from ampligraph.explanations import ExamplE
from ampligraph.utils import save_model, restore_model
from sklearn.calibration import calibration_curve
from scipy.special import expit
import matplotlib.pyplot as plt
import matplotlib
import pickle as pkl
from ampligraph.evaluation import remove_triples, get_calibration_set 
from ampligraph.compat import BACK_COMPAT_MODELS
from get_explanations import load_model_data
from training_time_experiment import train_model_with_probabilities_monit


def rank_triples_wrt_predictions(model, X, root, model_name):
    """Rank the test triples by predicted probability, best first.

    Parameters
    ----------
    model : object
        Anything exposing ``predict_proba(triples)``; it is called once
        with ``X['test']``.
    X : dict
        Dataset splits; only ``X['test']`` is read here and it is indexed
        as a 2-D array whose first three columns are subject, predicate
        and object.
    root : str or None
        Directory used to cache the ranking.  When ``None`` nothing is read
        from or written to disk.
    model_name : str
        Used to build the cache file name.

    Returns
    -------
    pandas.DataFrame
        Columns ``s``, ``p``, ``o``, ``pred`` and ``index`` (position of the
        triple in ``X['test']``), sorted by ``pred`` in descending order.
    """
    # The cache file carries a ``.pkl`` suffix but is written and read as CSV.
    # Without a root there is no cache at all: never probe a literal "None/"
    # directory.
    preds_path = None if root is None else f"{root}/predictions_{model_name}.pkl"
    if preds_path is not None and os.path.exists(preds_path):
        return pd.read_csv(preds_path)

    predictions = model.predict_proba(X['test'])

    df = pd.DataFrame()
    df['s'] = X['test'][:, 0]
    df['p'] = X['test'][:, 1]
    df['o'] = X['test'][:, 2]
    df['pred'] = predictions
    # Remember the original position so explanations can be looked up later.
    df['index'] = list(range(len(df)))
    # Sort triples according to predictions (descending order).
    df = df.sort_values(by=['pred'], ascending=False)

    if preds_path is not None:
        df.to_csv(preds_path, index=False)
    return df


def roar_training(explanation, X, parameters, model_name, gpu, roar_path, rev, n, monit_freq=None, root='./', exp_name='', index='', explainer_name='none'):
    """Retrain a fresh model on a training set modified by an explanation.

    In the regular (ROAR) mode the first ``n`` explanation triples are
    removed from ``X['train']``.  In the reversed mode (``rev``) every
    training triple sharing the predicate of the first explanation triple
    is removed and the first ``n`` explanation triples are appended back.

    Parameters
    ----------
    explanation : array-like
        Explanation rows; the first three columns are used as a triple and
        column 1 as the predicate.  The reversed mode slices it as a 2-D
        NumPy array.
    X : dict
        Dataset with ``'train'``, ``'valid'`` and ``'test'`` splits.
    parameters : dict
        Keyword arguments for the model constructor.
    model_name : str
        Attribute name of the model class in ``ampligraph.compat``.
    gpu : str
        Value assigned to ``CUDA_VISIBLE_DEVICES``.
    roar_path : str
        Destination passed to ``save_model`` after training.
    rev : bool
        Select the reversed experiment.
    n : int or None
        Number of explanation triples to use.  ``None`` or a negative value
        (the CLI maps ``all`` to ``-1``) means "use every triple".
    monit_freq : int, optional
        When given, training is delegated to
        ``train_model_with_probabilities_monit``; otherwise the model is
        fitted directly with early stopping on MRR.
    root, exp_name, index, explainer_name
        Only used to name/locate the monitoring output.

    Returns
    -------
    tuple
        ``(model, X_updated)`` where ``X_updated`` is the dict of modified
        splits, or ``None`` when ``explanation`` is empty (in which case the
        model is returned untrained and nothing is saved).
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu
    model_class = getattr(ampligraph.compat, model_name)
    re_model = model_class(**parameters)

    if len(explanation) == 0:
        return re_model, None

    # A negative ``n`` stands for "all"; slicing with it directly would
    # silently drop the last explanation triple ([:-1]).
    limit = None if n is None or n < 0 else n

    if rev:
        # From triples with the same predicate leave only the explanations.
        pool = X['train'][X['train'][:, 1] == explanation[0][1]]
        train = remove_triples(X['train'], pool)
        train = np.concatenate([train, explanation[:, :3][:limit]])
    else:
        train = remove_triples(X['train'], explanation[:limit])

    # NOTE(review): get_calibration_set is defined elsewhere; presumably it
    # re-derives the split against the new training set - confirm.
    valid_set = get_calibration_set(X['valid'], train)
    test_set = get_calibration_set(X['test'], train)
    X_updated = {'train': train, 'valid': valid_set, 'test': test_set}

    if monit_freq is not None:
        name = f"{exp_name}_{index}_check_every_{monit_freq}_{n}_{'rev' if rev else ''}_roar"
        result, re_model = train_model_with_probabilities_monit(re_model, X_updated, name, root, monit_freq, explainer_name)
    else:
        eval_filter = np.concatenate((X_updated['train'], X_updated['valid'], X_updated['test']))
        re_model.fit(X_updated["train"], True,
                     {
                         # Every second validation triple is used for early stopping.
                         'x_valid': X_updated['valid'][::2],
                         'criteria': 'mrr',
                         'x_filter': eval_filter,
                         'stop_interval': 4,
                         'burn_in': 0,
                         'check_interval': 50
                     })
    save_model(re_model, roar_path)

    return re_model, X_updated


def get_explanation(index, explanations):
    """Return the entry of ``explanations`` stored under ``index``.

    ``explanations`` may be any container supporting ``[]`` lookup; lookup
    errors raised by the container propagate to the caller.
    """
    selected = explanations[index]
    return selected


def run_retraining_experiment(model_path, parameters, explanation_path, root, positive_rate, model_cls_name, gpu, rev, n, exclude_circular=True, monit_freq=None):
    """Retrain one model per test triple, driven by that triple's explanation.

    Test triples are visited in descending order of predicted probability.
    For each one the matching explanation is looked up by the triple's
    original test index and handed to ``roar_training``; the retrained model
    is stored next to ``model_path``.  Triples whose output path already
    exists are skipped, so an interrupted run can be resumed.

    Parameters
    ----------
    model_path : str
        Path of the trained model; its last ``/``-separated component is
        used as the model name and the path itself as the output prefix.
    parameters : dict
        Hyper-parameters forwarded to ``roar_training``.
    explanation_path : str
        Pickle file holding the explanations, indexable by test index.
    root : str or None
        Directory for cached predictions and monitoring output.
    positive_rate : float
        Forwarded to ``load_model_data``.
    model_cls_name : str
        Model class name forwarded to ``roar_training``.
    gpu : str
        GPU identifier forwarded to ``roar_training``.
    rev : bool
        Run the reversed experiment.
    n : int
        Number of explanation triples to use, forwarded to ``roar_training``.
    exclude_circular : bool
        Skip triples whose subject equals their object.
    monit_freq : int, optional
        Monitoring frequency forwarded to ``roar_training``.
    """
    model_name = model_path.split('/')[-1]
    # SECURITY: pickle executes arbitrary code on load - only point this at
    # explanation files produced by a trusted run.
    with open(explanation_path, 'rb') as f:
        explanations = pkl.loads(f.read())
    model, X = load_model_data(model_path, root, positive_rate)
    ranked_df = rank_triples_wrt_predictions(model, X, root, model_name)

    # Each row is (s, p, o, pred, index).
    for triple in tqdm(ranked_df.values):
        if exclude_circular and triple[0] == triple[2]:
            print(f"Circular triple detected, skipping... s:{triple[0]}==o:{triple[2]}")
            continue
        print(f"Target triple selected: {triple[4]} index, {triple[0:4]}")

        # Only the suffix is sanitised; model_path keeps its directories.
        if rev:
            suffix = f"_{triple[4]}_ind_{n}_rev".replace("/", '_')
        else:
            suffix = f"_{triple[4]}_ind".replace("/", '_')
        roar_path = f"{model_path}" + suffix

        if os.path.exists(roar_path):
            print(f"Path exists, skipping retraining... {roar_path}")
            continue

        explanation = get_explanation(triple[4], explanations)
        print(explanation)
        # A single row with -1 in column 4 marks "no explanation".
        if len(explanation) == 1 and explanation[0][4] == -1:
            # The row mixes strings with a float and an int, so the parts
            # must be stringified before joining.
            triple_text = '-->'.join(str(part) for part in triple[:3])
            print(f"No explanations found for triple: {triple_text}, try changing explainer parameters.")
            continue

        re_model, X_updated = roar_training(explanation, X, parameters, model_cls_name, gpu, roar_path, rev, n, monit_freq, root, exp_name=model_name, index=triple[4])
        if X_updated is not None:
            print("Test set affected by removing triple, we will save the updated dataset for reference.")
            with open(f"{roar_path}_data", 'wb') as f:
                f.write(pkl.dumps(X_updated))


if __name__ == "__main__":
    # Command-line entry point: parse the docopt usage at the top of the
    # file, look up the hyper-parameters in ./config.json and run the
    # retraining experiment.
    arguments = docopt(__doc__, version='Get explanations')
    root = arguments['--save']
    model_path = arguments['<model_path>']
    # --gpu and --monitfreq are read through their <gpu> / <frequency>
    # placeholders rather than the option names.
    gpu = arguments['<gpu>']
    explanation_path = arguments['<explanations_path>']
    positive_rate = float(arguments['--rate'])
    experiment_name = arguments['<experiment>']
    # --monitfreq is optional; int(None) would raise a TypeError.
    frequency = arguments['<frequency>']
    monit_freq = int(frequency) if frequency is not None else None
    # 1 - remove in ROAR first triple, in rev-ROAR leave only 1st triple
    n_expls = arguments['<n_explanations>']
    if n_expls == 'all':
        n_expls = -1
    else:
        n_expls = int(n_expls)
    rev = arguments['--rev']

    with open('./config.json') as f:
        config = json.loads(f.read())
    # The experiment name is "<first>_<second>[...]"; both parts are used
    # upper-cased as configuration keys.
    name_parts = experiment_name.split('_')
    first_key = name_parts[0].upper()
    model_key = name_parts[1].upper()
    params = config['hyperparams'][first_key][model_key]

    if root is not None:
        # makedirs also creates missing parents and tolerates an existing dir.
        os.makedirs(root, exist_ok=True)
    print(arguments)
    model_name = config['model_name_map'][model_key]
    run_retraining_experiment(model_path, params, explanation_path, root, positive_rate, model_name, gpu, rev, n_expls, monit_freq=monit_freq)
