import pickle
import numpy as np
from sklearn.metrics import top_k_accuracy_score


# Datasets to evaluate, in evaluation order. Each key is the directory name
# used under ../checkpoints/; each value is presumably the display label for
# that dataset (values are not read by the code below — TODO confirm usage).
# Entries that are commented out are excluded from the run.
DATASET_LABELS = {
    'amazon': 'Amazon',
    # 'retweet': 'Retweet',
    'retweet_jitter': 'Retweet',
    'taxi': 'Taxi',
    'taobao': 'Taobao',
    'stackoverflow': 'StackOverflow',
    'lastfm': 'Last.fm',
    'mimic_jitter': 'MIMIC-II',
    # 'ehrshot': 'EHRShot',
}

# Model names; each one is also the sub-directory name inside a dataset's
# checkpoint directory.
MODEL_NAMES = [
    'RMTPP',
    'NHP',
    'SAHP',
    'THP',
    'AttNHP',
    'IntensityFree',
    'DLHP',
]

def _mark_accuracy(true_mark, mark_pred):
    """Return the fraction of positions where the predicted mark equals the true mark.

    Both inputs are flattened before being compared so the result cannot be
    distorted by NumPy broadcasting: comparing shape ``(N,)`` against
    ``(N, 1)`` would otherwise yield an ``N x N`` matrix and a meaningless
    mean.

    Args:
        true_mark: array-like of ground-truth marks.
        mark_pred: array-like of predicted marks, one per ground-truth mark.

    Returns:
        The mean of the element-wise equality, or NaN when both inputs are
        empty.

    Raises:
        ValueError: if the inputs hold a different number of elements.
    """
    truth = np.asarray(true_mark).ravel()
    guess = np.asarray(mark_pred).ravel()
    if truth.size != guess.size:
        raise ValueError(
            f'true_mark has {truth.size} elements but mark_pred has {guess.size}'
        )
    if truth.size == 0:
        # Nothing to score; avoid NumPy's "mean of empty slice" warning.
        return float('nan')
    return np.mean(truth == guess)


# Print the mark-prediction accuracy for every (dataset, model) pair.
for dataset in DATASET_LABELS:
    print(f'Current dataset: {dataset}')
    for model in MODEL_NAMES:
        print(f'Current model: {model}')
        # The ehrshot/AttNHP combination is skipped (presumably no results
        # exist for it — TODO confirm).
        if dataset == 'ehrshot' and model == 'AttNHP':
            continue

        checkpoint_dir = f'../checkpoints/{dataset}/{model}'
        # NOTE: pickle.load can execute arbitrary code; only load checkpoint
        # files that come from a trusted source.
        loaded = {}
        for stem in ('true_mark', 'mark_pred'):
            with open(f'{checkpoint_dir}/{stem}.pkl', 'rb') as f:
                loaded[stem] = pickle.load(f)

        acc = _mark_accuracy(loaded['true_mark'], loaded['mark_pred'])
        print(f'Acc: {acc}')

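# Disabled experiment: top-1 mark accuracy for the 'ehrshot' dataset, computed
# from the stored mark confidences (668 mark classes) with top_k_accuracy_score.
# dataset = 'ehrshot'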
# models = ['RMTPP', 'NHP', 'SAHP', 'THP', 'IntensityFree', 'DLHP']
# model = 'RMTPP'
# with open(f'../checkpoints/{dataset}/{model}/true_mark.pkl', 'rb') as f:
#     true_mark = pickle.load(f)
#
# with open(f'../checkpoints/{dataset}/{model}/mark_conf.pkl', 'rb') as f:
#     mark_conf = pickle.load(f)
# mark_conf = np.array(mark_conf).reshape((-1, 668))
# acc = top_k_accuracy_score(np.array(true_mark), np.concatenate(mark_conf, axis=0), k=1, labels=np.array(list(range(668))))
# print(acc)