"""
Run experiments:
Usage:
  run_experiments.py --gpu <gpu> 
  run_experiments.py -h | --help
  run_experiments.py --version

Options:
  -h --help              Show this screen.
  --version              Show version.
"""
# Copyright 2022 The Authors. All Rights Reserved.
#
# This file is Licensed under the Apache License, Version 2.0.
# A copy of the Licence is available in LICENCE, or at:
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
from docopt import docopt
import os
# Silence TensorFlow's native (C++) logging. This must happen *before*
# TensorFlow is imported: setting it afterwards is too late for messages the
# native runtime configures/emits during import.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
import numpy as np
import pandas as pd
from tqdm import tqdm
import json
import ampligraph
from ampligraph import latent_features
from ampligraph.datasets import load_fb15k_237, load_wn18rr, load_yago3_10
from ampligraph.evaluation import evaluate_explainer, mrr_score, hits_at_n_score
from ampligraph.explanations import ExamplE
from ampligraph.utils import save_model, restore_model
from sklearn.calibration import calibration_curve
from scipy.special import expit
import matplotlib.pyplot as plt
import matplotlib
import pickle as pkl
from datetime import datetime
from ampligraph.evaluation import remove_triples, get_calibration_set 
from ampligraph.compat import BACK_COMPAT_MODELS
from get_explanations import load_model_data, get_explanations
from training_time_experiment import train_model_with_probabilities_monit
from retraining_experiment import rank_triples_wrt_predictions, get_explanation, roar_training
from npencoder import NpEncoder

def get_experiment_config(save_root='./results'):
    """Load (or create) the experiment configuration and register this run.

    On the first call a default configuration is created. On later calls the
    stored ``<save_root>/config.json`` is loaded instead. In both cases a
    timestamp for the current run is appended to ``config['runs']``. The
    configuration is then written back to disk, so the file keeps a history
    of every run rather than only the first one.

    Args:
        save_root: Directory that holds ``config.json``. It is created
            (including parents) if missing. Defaults to ``'./results'``, the
            previously hard-coded location.

    Returns:
        dict: The experiment configuration. The current run's timestamp is
        the last entry of ``config['runs']``.
    """
    date_time = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
    os.makedirs(save_root, exist_ok=True)
    cfg_path = os.path.join(save_root, 'config.json')
    if os.path.exists(cfg_path):
        with open(cfg_path, 'r') as f:
            config = json.load(f)
        config['runs'].append(date_time)
    else:
        config = {
            'n': 1,
            'explainers': ['example', 'random'],
            'reruns': ['roar-all', 'rev-roar-all', 'roar-1', 'rev-roar-1'],
            'save_root': save_root,
            'models': {'transe_wn18rr': '../pretrained_models/TRANSE-WN18RR-2022-11-02-18-55-03_calibrated',
                       'transe_fb15k-237': '../pretrained_models/TRANSE-FB15K-237-2022-11-02-19-28-40_calibrated'},
            'runs': [date_time],
            # NOTE: 'model_config' is a machine-specific absolute path; adjust per environment.
            'misc': {'positive_rate': 0.5, 'model_config': '/home/ajanik/AmpliGraph-Lab/experiments/config.json',
                     'monit_freq': 10}
        }
    # Persist unconditionally. Previously the appended run timestamp was only
    # kept in memory and lost once the process exited.
    with open(cfg_path, 'w') as f:
        json.dump(config, f)
    return config


def _dump_results(results, save_root):
    """Write ``results`` as JSON to ``<save_root>/results.json`` (numpy-aware via NpEncoder)."""
    with open(f'{save_root}/results.json', 'w') as f:
        f.write(json.dumps(results, cls=NpEncoder))


def _load_or_init_results(save_root):
    """Return the stored results dict, creating an empty ``results.json`` if there is none."""
    path = f'{save_root}/results.json'
    if os.path.exists(path):
        with open(path, 'r') as f:
            return json.loads(f.read())
    results = {}
    _dump_results(results, save_root)
    return results


def _mark_selected(results, model_name, triple):
    """Append ``triple`` to ``results[model_name]['selected']`` unless an equal row is already there."""
    selected = results[model_name]['selected']
    if tuple(triple) not in {tuple(t) for t in selected}:
        selected.append(triple)


def _load_model_params(config, model_name):
    """Return ``(model class name, hyper-parameters)`` for ``model_name`` from the model config file.

    ``model_name`` is split on ``'_'``. The first part (upper-cased) selects
    the model and the second part (upper-cased) selects the hyper-parameter
    group, e.g. ``'transe_wn18rr'`` -> ``'TRANSE'`` / ``'WN18RR'``.
    """
    with open(config['misc']['model_config']) as f:
        model_config = json.loads(f.read())
    parts = model_name.split('_')
    model_key, dataset_key = parts[0].upper(), parts[1].upper()
    model_cls_name = model_config['model_name_map'][model_key]
    params = model_config['hyperparams'][dataset_key][model_key]
    return model_cls_name, params


def main(gpu):
    """Run the ROAR retraining experiments described by the experiment config.

    For every entry in ``config['models']`` this:

    1. Loads the model and data with ``load_model_data``.
    2. Orders candidate triples with ``rank_triples_wrt_predictions``.
    3. For every explainer in ``config['explainers']``, computes (and times)
       explanations on first use, or reloads the pickled explanations when
       the explainer already has an entry in the results.
    4. Walks the ranked triples, skipping those whose subject equals the
       object, until ``config['n']`` triples are selected. For each selected
       triple it calls ``roar_training`` once per entry in
       ``config['reruns']``.

    Progress is checkpointed to ``<save_root>/results.json`` after each step.
    Retraining is skipped when its output path already exists; such triples
    still count towards ``config['n']``, so resumed runs do not train extra
    triples.

    Args:
        gpu: GPU identifier, forwarded unchanged to ``roar_training``.
    """
    config = get_experiment_config()
    print(config)
    save_root = config['save_root']
    results = _load_or_init_results(save_root)

    for model_name, model_path in config['models'].items():
        results.setdefault(model_name, {})
        root = f'{save_root}/{model_name}'
        os.makedirs(root, exist_ok=True)
        model, X = load_model_data(model_path, root, config['misc']['positive_rate'])
        ranked_trips = rank_triples_wrt_predictions(model, X, root, model_name)

        for explainer_name in config['explainers']:
            os.makedirs(os.path.join(root, explainer_name), exist_ok=True)
            if explainer_name not in results[model_name]:
                explanations, elapsed = get_explanations(model, X, model_name, root, explainer_name)
                results[model_name][explainer_name] = {'total_time_get_explanation': elapsed, 'time-per-triple': elapsed/len(X['test'])}
                _dump_results(results, save_root)
            else:
                # NOTE(review): unpickling is only safe because this file is produced
                # locally by get_explanations; never point this at untrusted data.
                with open(os.path.join(root, f'explanations_{explainer_name}_{model_name}.pkl'), 'rb') as f:
                    explanations = pkl.load(f)

            # The selection is rebuilt for every explainer.
            results[model_name]['selected'] = []
            for triple in tqdm(ranked_trips.values):
                if len(results[model_name]['selected']) >= config['n']:
                    break
                if triple[0] == triple[2]:
                    print(f"Circular triple detected, skipping... s:{triple[0]}==o:{triple[2]}")
                    continue
                print(f"Target triple selected: {triple[4]} index, {triple[0:4]}")
                for rerun in config['reruns']:
                    n_expls = -1 if 'all' in rerun else 1
                    rev = 1 if 'rev' in rerun else 0
                    roar_path = os.path.join(root, explainer_name, f"{triple[4]}_ind_{explainer_name}_{n_expls}_{'rev' if rev == 1 else ''}roar".replace("/",'_'))

                    if os.path.exists(roar_path):
                        print(f"Path exists, skipping retraining... {roar_path}")
                        # Count already-finished triples towards config['n'];
                        # otherwise a resumed run keeps selecting new triples.
                        _mark_selected(results, model_name, triple)
                    else:
                        explanation = get_explanation(triple[4], explanations)
                        print(explanation)
                        # A single row whose 5th field is -1 is treated as "no explanation".
                        if len(explanation) == 1 and explanation[0][4] == -1:
                            # The row holds non-str values (e.g. the integer index),
                            # so stringify before joining.
                            print(f"No explanations found for triple: {'-->'.join(map(str, triple))}, try changing explainer parameters.")
                            continue
                        _mark_selected(results, model_name, triple)
                        _dump_results(results, save_root)
                        model_cls_name, params = _load_model_params(config, model_name)
                        _, X_updated = roar_training(explanation, X, params, model_cls_name, gpu, roar_path, rev, n_expls, config['misc']['monit_freq'], root, exp_name=model_name, index=triple[4], explainer_name=explainer_name)
                        if X_updated is not None:
                            print("Test set affected by removing triple, we will save the updated dataset for reference.")
                            with open(f"{roar_path}_data", 'wb') as f:
                                pkl.dump(X_updated, f)
                    # Only the first path recorded for each rerun type is kept.
                    results[model_name][explainer_name].setdefault(rerun, roar_path)
                    _dump_results(results, save_root)

            _dump_results(results, save_root)


if __name__ == "__main__":
    # Parse the command line against the usage text in the module docstring;
    # the GPU id arrives as the positional ``<gpu>`` value.
    cli_args = docopt(__doc__, version='Run Experiments')
    selected_gpu = cli_args['<gpu>']
    print(f"Running on GPU number: {selected_gpu}")
    main(selected_gpu)
