from ampligraph.datasets import load_wn18rr
from ampligraph.utils import restore_model
import os
import pandas as pd
import numpy as np
import pickle as pkl

### The top triple in the WN18RR test set had its inverse triple in the training set

def validate_triple(path_name):
    """Cross-check a restored model against its stored predictions and data.

    Restores the model saved at ``path_name``, calibrates it on the WN18RR
    validation split (``positive_base_rate=0.5``) and scores the WN18RR test
    split. The resulting probabilities are lined up with the predictions CSV
    derived from ``path_name`` and compared with its ``pred`` column. The
    first triple of that CSV is also rebuilt and scored on its own. Finally
    the training triples stored in the pickled ``<path_name>_data`` file are
    compared with the stock WN18RR training split.

    Parameters
    ----------
    path_name : str
        Path of the saved model. It must contain the substring
        ``calibrated``; everything up to and including that substring is
        used to build the predictions CSV path.

    Returns
    -------
    pandas.DataFrame
        The predictions table with two added columns: ``roar_proba_0``
        (probability from the restored model) and ``diff_0``
        (``pred - roar_proba_0``).
    """
    # The CSV stores entity IDs as numbers, so their leading zeros are lost.
    # Assumes WN18RR entity IDs are 8-digit zero-padded strings, so padding
    # to a fixed width restores any ID, not just the one this was written for.
    entity_id_width = 8

    data_path = f"{path_name}_data"
    # NOTE(review): for a path_name starting with "./" this expands to
    # "predictions_./..." - confirm that this matches the on-disk layout.
    predictions_path = f"predictions_{path_name.split('calibrated')[0]}calibrated"

    dataset = load_wn18rr()
    model = restore_model(path_name)
    model.calibrate(dataset['valid'], positive_base_rate=0.5)
    test_probas = model.predict_proba(dataset['test'])

    df = pd.read_csv(predictions_path)

    # Rebuild the first triple of the CSV in the string form the model expects.
    top_triple = [
        str(df['s'].iloc[0]).zfill(entity_id_width),
        str(df['p'].iloc[0]),
        str(df['o'].iloc[0]).zfill(entity_id_width),
    ]
    print(top_triple)
    top_proba = model.predict_proba(np.asarray([top_triple]))
    print(f"probability of top triple: {top_proba}")

    # presumably the 'index' column holds row positions into the test split;
    # verify against the code that wrote the CSV.
    df['roar_proba_0'] = test_probas[df['index'].values]
    df['diff_0'] = df['pred'] - df['roar_proba_0']

    # np.corrcoef returns the 2x2 correlation matrix; the off-diagonal entry
    # is the coefficient between the two series.
    corr = np.corrcoef(df['pred'].values, df['roar_proba_0'].values)
    print(f"correlation between stored and recomputed probabilities: {corr[0, 1]}")

    # NOTE: unpickling can execute arbitrary code; only use this on data
    # files that come from a trusted source.
    with open(data_path, 'rb') as f:
        stored = pkl.load(f)

    stored_train = {(row[0], row[1], row[2]) for row in stored['train']}
    original_train = {(row[0], row[1], row[2]) for row in dataset['train']}
    # Triples of the stock WN18RR training split missing from the stored one.
    print(original_train.difference(stored_train))

    return df


# Saved model to validate. validate_triple also derives the pickled data file
# ("<path_name>_data") and the predictions CSV path from this same string.
path_name = "./pretrained_models/figures/TRANSE-WN18RR-2022-11-02-18-55-03_calibrated_7118002__derivationally_related_form_951206"
# Module-level call: the check runs whenever this file is executed or imported.
validate_triple(path_name)
