from tqdm import tqdm
import numpy as np
import pandas as pd
import time
import os
import json
import pickle as pkl
from scipy import stats
import tempfile

def get_correlation(post, pred, show=False, save=None):
    """Get Pearson coefficient to measure correlation between
       probabilities before and after retraining.

       Parameters
       ----------
       post: array-like
           predicted probabilities after retraining. Entries that are
           None, NaN or infinite are excluded from the correlation.
       pred: array-like
           original predicted probabilities, aligned with `post`.
       show: bool
           whether to display chart.
       save: str
           save path for correlation plot.

       Returns
       -------
       pearson_coef: float
           Pearson coefficient (NaN when fewer than two complete cases).
       p_value: float
           P-value of Pearson coefficient (NaN when fewer than two complete cases).
       support: int
           Number of complete cases.
       all_values: int
           Number of all cases.
    """
    # Coerce both inputs to flat float arrays: None becomes NaN, and
    # single-column dataframes / (n, 1) arrays are flattened so that the
    # same index vector can be applied to both.
    post_values = np.asarray(post, dtype=float).ravel()
    pred_values = np.asarray(pred, dtype=float).ravel()
    complete = np.flatnonzero(np.isfinite(post_values))  # exclude nans - complete case analysis
    support = len(complete)
    all_values = len(post_values)
    if support < 2:
        # pearsonr raises for fewer than two observations; report the
        # degenerate case instead of crashing the whole evaluation.
        print("Pearson correlation is undefined for fewer than two complete cases (support={}/{})".format(
            support, all_values))
        return float('nan'), float('nan'), support, all_values

    pearson_coef, p_value = stats.pearsonr(pred_values[complete], post_values[complete])
    msg = "The Pearson Correlation Coefficient is {},  with a P-value of P ={}, (support={}/{})"
    print(msg.format(pearson_coef, p_value, support, all_values))
    if save is not None or show:
        # Plotting libraries are imported lazily so the numeric part of this
        # function works without them.
        import seaborn as sns
        import matplotlib.pyplot as plt
        # The style has to be set before the axes are created to affect this plot.
        sns.set_style("white", {'axes.grid': False})
        fig = plt.figure(figsize=(14, 10))
        p = sns.regplot(x=post_values, y=pred_values)
        font = 18
        # Ticks are pinned before relabelling so labels and positions stay in sync.
        p.set_yticks(p.get_yticks().tolist())
        p.set_yticklabels(["{:.2}".format(t) for t in p.get_yticks()], size=font)
        p.set_xticks(p.get_xticks().tolist())
        p.set_xticklabels(["{:.2}".format(t) for t in p.get_xticks()], size=font)
        msg = "Probabilities before and after retraining \nPearson coef: {:.3}, p-value: {:.1}"
        plt.title(msg.format(pearson_coef, p_value), fontsize=font*1.4)
        plt.ylabel("Original", fontsize=font)
        plt.xlabel("After retraining", fontsize=font)
        sns.despine(bottom=True, left=True)
        if save is not None:
            plt.savefig(save)
        if show:
            plt.show()
        else:
            # Nothing will display this figure, so release it.
            plt.close(fig)
    return pearson_coef, p_value, support, all_values


def remove_triples(X, triples):
    """Remove set of triples from an array.

       Every row of X that equals one of the given triples is removed,
       including repeated occurrences; all other rows are kept in their
       original order.

       Parameters
       ----------
       X: np.array
           original array of triples to be modified, with subject,
           predicate and object in columns 0, 1 and 2.
       triples: np.array
           array of triples to be removed from the X.

       Returns
       -------
       updated: np.array
           original array without specific triples.
    """
    X = np.asarray(X)
    df = pd.DataFrame({'s': X[:, 0], 'p': X[:, 1], 'o': X[:, 2]})
    # A set of tuples gives O(1) membership tests instead of filtering the
    # whole frame once per triple.
    to_remove = {(t[0], t[1], t[2]) for t in triples}
    rows = zip(X[:, 0].tolist(), X[:, 1].tolist(), X[:, 2].tolist())
    keep = np.array([row not in to_remove for row in rows], dtype=bool)
    updated = df[keep].values
    return updated

def get_calibration_set(valid, train):
    """Update validation set to include only triples,
       which entities and predicates are also in train set.

       Parameters
       ----------
       valid: np.array
           original validation triples, with subject, predicate and
           object in columns 0, 1 and 2.
       train: np.array
           original training set, same column layout.

       Returns
       -------
       updated: np.array
           validation triples whose subject and object both occur as an
           entity (subject or object) in train and whose predicate occurs
           in train. Empty array when no triple qualifies.
    """
    # Entities are the subjects (column 0) and objects (column 2);
    # column 1 holds predicates and must not be mixed into this set.
    entities = set(train[:, 0])
    entities.update(train[:, 2])
    predicates = set(train[:, 1])
    kept = [t for t in valid
            if t[0] in entities and t[2] in entities and t[1] in predicates]
    updated = np.array(kept)
    return updated

def check_coverage(train, query):
    """Checks whether components of query triple
       are present in the training set.

       Parameters
       ----------
       train: np.array
           training set, with subject, predicate and object in
           columns 0, 1 and 2.
       query: np.array
           query triple (subject, predicate, object).

       Returns
       -------
       status: bool
           True - query triple components are represented,
           False - query triple components are not represented.
    """
    entities = set(train[:, 0])
    entities.update(train[:, 2])
    predicates = set(train[:, 1])
    # Always return an explicit bool so callers can rely on the documented type.
    return bool(query[0] in entities and query[2] in entities and query[1] in predicates)
    
def retrain_model(model, X, explanation, get_explanation, verbose=False, local=True):
    """Retrain model.

       A fresh model with the hyperparameters of `model` (single batch) is
       fitted either on the explanation triples only (local=True) or on the
       training set with the explanation triples removed (local=False), and
       is then calibrated so that probabilities can be predicted with it.

       Parameters
       ----------
       model: EmbeddingModel
           model that clone needs to be retrained.
       X: dict
           dataset dictionary; the 'train' and 'valid' entries are used.
       explanation: dict
           explanation dictionary for which retraining will be done;
           its 'query_triple' entry is the explained triple.
       get_explanation: function
           function to get explanation triples from explanation dictionary.
       verbose: bool
           verbosity flag passed to the retrained model.
       local: bool
           whether to retrain model on explanations only or on data without explanation (local=False).

       Returns
       -------
       re_model: EmbeddingModel
           retrained and calibrated model, or None when the explanation is
           empty or (local=False) the query triple components are no longer
           covered by the reduced training set.
    """
    # Imported here to avoid circular imports.
    from ampligraph.latent_features import MODEL_REGISTRY
    parameters = model.get_hyperparameter_dict()
    parameters['verbose'] = verbose
    parameters['batches_count'] = 1
    re_model = MODEL_REGISTRY[model.name](**parameters)
    explanation_graph = get_explanation(explanation)
    if len(explanation_graph) == 0:
        return None

    query = explanation['query_triple']
    if local:
        ex = [tuple(e) for e in explanation_graph]
        if tuple(query) not in ex:
            print("Query not in explanation Graph - adding!")
            # asarray: the query may be stored as a list/tuple rather than an ndarray.
            explanation_graph = np.concatenate([explanation_graph, np.asarray(query).reshape(-1, 3)], 0)

        explanation_graph = explanation_graph.reshape(-1, 3)
        re_model.fit(explanation_graph)
        calibration_set = get_calibration_set(X['valid'], explanation_graph)
        if calibration_set is None or len(calibration_set) <= 0:
            print("Calibrating with full data - no triples in validation set that overlap with explanation graph.")
            calibration_set = explanation_graph
        re_model.calibrate(calibration_set, positive_base_rate=0.5)
        return re_model

    # Non-local: retrain on the training data with the explanation removed.
    # The query itself is first taken out of the explanation.
    explanation_graph = remove_triples(explanation_graph, [query])
    ex = [tuple(e) for e in explanation_graph]
    if tuple(query) in ex:
        print("Query inside explanation Graph!")
    train = remove_triples(X['train'], explanation_graph)
    if not check_coverage(train, query):
        return None
    re_model.fit(train)
    # Calibrate exactly like the local branch; without this step the caller's
    # probability prediction has no calibration to rely on.
    calibration_set = get_calibration_set(X['valid'], train)
    if calibration_set is None or len(calibration_set) <= 0:
        print("Calibrating with training data - no triples in validation set that overlap with retraining data.")
        calibration_set = train
    re_model.calibrate(calibration_set, positive_base_rate=0.5)
    return re_model

        
def retraining_experiment(model, X, explanations, get_explanation, path=tempfile.gettempdir(), local=True):
    """Retrain models for list of explanations.
       Function saves calculated probabilities in file.

       Probabilities already stored in `post_probability.csv` inside `path`
       are loaded first and only the remaining explanations are processed.
       The file is rewritten on every 100th processed explanation and once
       more at the end.

       Parameters
       ----------
       model: EmbeddingModel
           original model for which explanations were generated.
       X: dict
           dataset dictionary.
       explanations: list
           list of generated explanations.
       get_explanation: function
           function to get explanation triples from explanation dictionary.
       path: str
           path were predictions will be saved.
       local: bool
           forwarded to retrain_model.

       Returns
       -------
       pred_post: list
           calculated predictions for query triples after retraining;
           None is stored where no retrained model was returned.

    """
    proba_path = os.path.join(path, "post_probability.csv")

    def _checkpoint(values):
        # Write all probabilities collected so far to the csv file.
        frame = pd.DataFrame()
        frame['probability'] = values
        frame.to_csv(proba_path, index=False)

    pred_post = []
    if os.path.exists(proba_path):
        pred_post = pd.read_csv(proba_path)['probability'].values.tolist()
    print("3. Start retraining experiment.")
    remaining = explanations[len(pred_post):]
    for step, explanation in enumerate(tqdm(remaining)):
        re_model = retrain_model(model, X, explanation, get_explanation, local=local)
        if re_model is None:
            pred_post.append(None)
        else:
            query = np.array([explanation['query_triple']])
            pred_post.append(re_model.predict_proba(query)[0])
        if step % 100 == 0:
            _checkpoint(pred_post)

    _checkpoint(pred_post)
    return pred_post

def calculate_differences(pred_pre, pred_post, path=tempfile.gettempdir(), plot=False, local=True):
    """Calculate difference between probabilities,
       between original model and retrained without explaining triples.

       The differences are also written to `proba_differences.csv`
       inside `path`.

       Parameters
       ----------
       pred_pre: list, np.array or dataframe
           probabilities before retraining the model.
       pred_post: list
           probabilities after retraining the model; None marks
           instances for which no retrained probability exists.
       path: str
           path to the saving folder.
       plot: bool
           whether to plot differences on the bar plot. When True,
           instances with a None post-probability are left out of the
           result; when False they are kept as None placeholders.
       local: bool
           only selects the wording of the plot title.

       Returns
       -------
       diffs: list of differences between probabilities.
    """
    # Accept plain lists and arrays as well as single-column dataframes.
    originals = np.asarray(pred_pre, dtype=float).ravel().tolist()
    if plot:
        import matplotlib.pyplot as plt
        diffs = [pred - post for pred, post in zip(originals, pred_post) if post is not None]
        plt.bar(range(len(diffs)), diffs)
        if local:
            msg = "on explanation graph"
        else:
            msg = "without explanations"
        plt.title("Probability difference between original model and retrined {}".format(msg))
        plt.xlabel("Instances")
        plt.ylabel("Probability Difference")
    else:
        diffs = [pred - post if post is not None else None for pred, post in zip(originals, pred_post)]
    df = pd.DataFrame()
    df['diff'] = diffs
    df.to_csv(os.path.join(path, 'proba_differences.csv'), index=False)
    return diffs

def concat(a, b):
    """Join two arrays along the first axis.

       When only one of them is non-empty that one is returned as is;
       when both are empty an empty array is returned.
    """
    has_a = len(a) > 0
    has_b = len(b) > 0
    if has_a and has_b:
        return np.concatenate([a, b], 0)
    if has_a:
        return a
    if has_b:
        return b
    return np.array([])


def evaluate_explainer(model, X, path=tempfile.gettempdir(), n=-1, explainer='example', params=None, name="evaluation", plot=False, local=True):
    """Runs fidelity experiment for a given model following ROAR
       (remove and retrain evaluation Hooker et al. 2019).

       Explanations, original probabilities and summary information are
       cached in `path` (explanations.pkl, pred_probability.csv,
       explanation_info.json); when all three files exist they are loaded
       instead of being recomputed.

       Parameters
       ----------
       model: EmbeddingModel
           original model for which explanations will be generated.
       X: dict
           dataset dictionary; the 'train' and 'test' entries are used here.
       path: str
           path where explanations and predictions will be saved.
       n: int
           number of explanations to generate default -1 -> all.
       explainer: str
           explainer to be used (example, random_explainer).
       params: dict
           parameters passed to model.predict_explain (default: empty dict).
       name: str
           suffix used in the file name of the correlation plot.
       plot: bool
           whether to save the correlation plot inside `path`.
       local: bool
           forwarded to the retraining experiment and difference calculation.

       Returns
       -------
       info: dict
           summary of the experiment (timings, explanation size and
           sparsity statistics, probability differences, Pearson
           correlation); the same content is written to explanation_info.json.

       Raises
       ------
       NotImplementedError
           when `explainer` is not one of the supported names.
    """
    if params is None:
        params = {}
    if explainer == 'example':
        def get_explanation(elem):
            return concat(elem['prototype'], elem['examples'])
    elif explainer == 'random_explainer':
        def get_explanation(elem):
            return elem['relevant_triples']
    else:
        raise NotImplementedError("Unknown explainer: {!r}".format(explainer))

    explanation_path = os.path.join(path, 'explanations.pkl')
    proba_path = os.path.join(path, "pred_probability.csv")
    info_path = os.path.join(path, "explanation_info.json")
    cached = all(os.path.exists(p) for p in (explanation_path, proba_path, info_path))
    if cached:
        print(f"Files will be loaded: {path}")
        # NOTE(security): pickle executes arbitrary code on load - only point
        # `path` at directories whose content is trusted.
        with open(explanation_path, 'rb') as f:
            explanations = pkl.load(f)
        pred_pre = pd.read_csv(proba_path)
        with open(info_path) as f:
            info = json.load(f)
    else:
        print("Files will be created inside: {}".format(path))
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)
        else:
            print("Path {} exists! You may want to delete it before running this script.".format(path))
        start = time.time()
        print("1. Make predictions and collect explanations for them.")
        # n == -1 means "all": slicing with [:-1] would drop the last triple.
        test_triples = X['test'] if n == -1 else X['test'][:n]
        explanations = model.predict_explain(test_triples, explainer=explainer, params=params)
        end = time.time()
        sizes = [len(get_explanation(elem)) for elem in explanations]
        all_trips = len(X['train'])
        sparsities = [1 - size / all_trips for size in sizes]
        # Report the number of triples that were actually explained.
        n = len(test_triples)
        elapsed = end - start
        info = {'predict_explain_test_set_time [s]': elapsed,
                'predict_explain_time per triple [s]': elapsed / n if n else 0.0,
                'n_elements': n,
                'average_expl_size [n triples]': float(np.mean(sizes)),
                'std_expl_size [n triples]': float(np.std(sizes)),
                'average_sparsity': float(np.mean(sparsities)),
                'std_sparsity': float(np.std(sparsities))
                }

        with open(info_path, 'w') as f:
            f.write(json.dumps(info))

        with open(explanation_path, 'wb') as f:
            f.write(pkl.dumps(explanations))
        df = pd.DataFrame()
        df['probability'] = [e['probability'][0] for e in explanations]
        df.to_csv(proba_path, index=False)
        # Keep the same type as in the cached branch (dataframe) so the
        # steps below behave identically on fresh and resumed runs.
        pred_pre = df
    print("2. Evaluate probability drop and correlation after retraining models without explanations.")
    pred_post = retraining_experiment(model, X, explanations, get_explanation, path, local=local)
    print(len(pred_post))
    print(len(pred_pre))
    print(pred_post)
    diffs = calculate_differences(pred_pre, pred_post, path=path, plot=True, local=local)
    # None placeholders (if any) become NaN and are ignored by the nan-aware statistics.
    diff_values = np.array(diffs, dtype=float)
    has_diffs = diff_values.size > 0 and not np.isnan(diff_values).all()
    info['mean_proba_diff'] = float(np.nanmean(diff_values)) if has_diffs else -1.
    info['std_proba_diff'] = float(np.nanstd(diff_values)) if has_diffs else -1.
    if plot:
        save_path = os.path.join(path, f"correlation_{name}.pdf")
    else:
        save_path = None
    # Hand over flat float arrays: failed retrainings (None) become NaN and
    # the single-column dataframe of original probabilities becomes 1-D.
    post_values = np.array(pred_post, dtype=float)
    pre_values = np.asarray(pred_pre, dtype=float).ravel()
    pearson_coef, p_value, support, all_values = get_correlation(post_values, pre_values, show=False,
                                                                 save=save_path)

    info['pearson'] = float(pearson_coef)
    info['pvalue'] = float(p_value)
    info['support'] = float(support)
    info['all_values'] = float(all_values)
    with open(info_path, 'w') as f:
        f.write(json.dumps(info))

    return info
