import json
from collections import defaultdict

import pandas as pd
from matplotlib import pyplot as plt
import os
import numpy as np
from dataset.baseDataset import baseDataset, QuadruplesDataset
from datetime import datetime
import seaborn as sns
# sns.set_theme(style="whitegrid")
# plt.rcParams['font.family'] = 'serif'
         # controls default text sizes
# plt.rc('axes', titlesize=14)     # fontsize of the axes title
# plt.rc('axes', labelsize=12)     # fontsize of the x and y labels
# plt.rc('xtick', labelsize=10)    # fontsize of the tick labels
# plt.rc('ytick', labelsize=10)    # fontsize of the tick labels
# plt.rc('legend', fontsize=10)    # legend fontsize
# plt.rc('figure', titlesize=14)
# plt.grid(True, which='both', linestyle='--', linewidth=0.5)

# Shared font sizes (in points) for the figures produced by this module.
# NOTE: `small` is not referenced anywhere in the visible part of this file.
large = 24; med = 14; small = 12
# Matplotlib rcParams overrides: axes/figure titles use `large`, every other
# text element uses `med`, and the font family is set to serif.
params = {'axes.titlesize': large,
          'legend.fontsize': med,
          'axes.labelsize': med,
          'xtick.labelsize': med,
          'ytick.labelsize': med,
          'figure.titlesize': large,
          'font.family': 'serif',
          'font.size': med}
# NOTE(review): presumably restricts PDF output to the 14 core PDF fonts --
# confirm against the matplotlib rcParams documentation.
plt.rcParams['pdf.use14corefonts'] = True
# plt.rc('font', size=12)
plt.rcParams.update(params)
# The seaborn calls below run after the rcParams update above.
# NOTE(review): sns.set_style() may itself overwrite some of the rcParams set
# above (e.g. 'font.family') -- confirm that the serif setting survives.
sns.set_context(rc={'lines.linewidth': 2.5})
sns.set_style("whitegrid")
sns.set_palette("tab10")




def inductive_results(data_name, data_path, configs, mode=True):
    """Print LaTeX-style table rows and plot running HITS@10 / MRR per config.

    For every config the rank files
    ``ranks_model_<m>_data_<d>_inductive_False.npy`` are loaded from the
    config's ``ranks/`` directory and reduced to HITS@1/3/10, MRR and sample
    count, stored in ``(model, snapshot)`` matrices.  The first model is
    evaluated on every snapshot; later models only on snapshots ``d <= m``.
    Missing rank files are reported and leave zeros in the matrices.

    Args:
        data_name: Dataset name; selects the number of snapshots and is
            substituted into the ``checkpoint_dir`` template.
        data_path: Directory holding ``g<d>/{train,valid,test}.txt`` and
            ``stat.txt``.
        configs: Iterable of dicts providing ``train_mode``, ``ewc_lambda``,
            ``buffer_size``, ``checkpoint_dir`` (a %-template), ``legend``
            and ``prefix``.
        mode: When truthy, only ranks of test quadruples whose first or third
            field is missing from ``baseData.seen_entities`` are kept;
            otherwise all ranks are used.

    Side effects:
        Prints progress and the accumulated rows, writes a PNG below
        ``../plots/`` and then shows the figure.
    """
    inductive = mode
    inductive_str = ''
    all_str = ''

    # One row, two columns: axs[0] shows HITS@10 and axs[1] shows MRR.
    # The axes array is therefore 1-D and must be indexed with one index.
    fig, axs = plt.subplots(1, 2, figsize=(10, 10))

    # The snapshot range depends only on the dataset, so compute it once.
    start = 1
    if data_name == 'GDELT-S':
        end = 15
    elif data_name == 'ICEWS14':
        end = 33
    elif data_name == 'JCPenny':
        end = 9
    else:
        end = 14
    result_temp_fn = 'ranks_model_%d_data_%d_inductive_False.npy'
    stat_path = os.path.join(data_path, 'stat.txt')

    for conf in configs:
        train_mode = conf['train_mode']
        ewc_lambda = conf['ewc_lambda']
        buffer_size = conf['buffer_size']
        check_point_dir = conf['checkpoint_dir'] % (data_name, train_mode, ewc_lambda, buffer_size)
        rankd_dir = check_point_dir + 'ranks/'

        # results[metric][m - start, d - start]: model m scored on snapshot d.
        results = {k: np.zeros((end - start, end - start))
                   for k in ['MRR', 'HITS@1', 'HITS@3', 'HITS@10', 'size']}

        for m in range(start, end):
            # Rebuilt for every model so the snapshots are replayed in order.
            baseData = None
            for d in range(start, end):
                if m != start and m < d:
                    continue

                snapshot = 'g%d' % d
                testpath = os.path.join(data_path, '%s/test.txt' % snapshot)
                trainpath = os.path.join(data_path, '%s/train.txt' % snapshot)
                validpath = os.path.join(data_path, '%s/valid.txt' % snapshot)
                if not baseData:
                    baseData = baseDataset(trainpath, testpath, stat_path, validpath, args=None)
                else:
                    baseData.update(trainpath, testpath, validpath)

                result_fn = result_temp_fn % (m, d)
                print(result_fn)
                try:
                    ranks = np.load(rankd_dir + result_fn, allow_pickle=True)
                except FileNotFoundError:
                    print('%s %s' % (conf['prefix'], result_fn))
                    continue

                if inductive:
                    test_dataset = QuadruplesDataset(baseData.testQuadruples, baseData.num_r, None, False)
                    indices = []
                    for idx, q in enumerate(test_dataset):
                        s, _, o, _, _ = q
                        if s not in baseData.seen_entities or o not in baseData.seen_entities:
                            indices.append(idx)
                    ranks = ranks[indices]

                for k in [1, 3, 10]:
                    results['HITS@%d' % k][m - start, d - start] = np.mean([r <= k for r in ranks])
                results['MRR'][m - start, d - start] = np.mean([1.0 / r for r in ranks])
                results['size'][m - start, d - start] = len(ranks)

        def _cells(row):
            """Format last-snapshot and row-mean HITS@10 / MRR of ``row``."""
            values = (
                results['HITS@10'][row, -1],
                results['MRR'][row, -1],
                np.mean(results['HITS@10'][row, :]),
                np.mean(results['MRR'][row, :]),
            )
            return ''.join('%.3f &' % v for v in values)

        rows = ''
        if conf['legend'] == 'naive':
            # Extra row built from the first model (matrix row 0).
            rows += '%s &' % 'FRZ'
            if inductive:
                rows += '%d &' % results['size'][0, -1]
            rows += _cells(0) + '\n'
        rows += '%s &' % conf['legend'] + _cells(-1) + '\n'

        if inductive:
            inductive_str += rows
            print(inductive_str)
        else:
            all_str += rows
            print(all_str)

        for i, k in enumerate(['HITS@10', 'MRR']):
            # Mean over the snapshots seen so far, for each model in turn.
            y = [results[k][ind, :ind + 1].mean() for ind in range(end - 1)]
            x = list(range(len(y)))
            axs[i].plot(x, y, label='%s' % conf['legend'])
            axs[i].set_title(k + (' inductive' if inductive else ' all'))

    axs[1].legend(loc='upper left', bbox_to_anchor=(1, 1), frameon=False)

    plt.tight_layout()
    # Save before show(): with an interactive backend the figure may be gone
    # once the window is closed, and saving afterwards writes an empty canvas.
    plt.savefig('../plots/%s_%s.png' % (data_name, datetime.today().strftime('%Y-%m-%d-%H:%M:%S')))
    plt.show()

    print(inductive_str)
    print(all_str)


def plot_print_results_union(data_name, data_path, configs):
    """Print and plot metrics computed over ranks pooled across snapshots.

    For every config and every model ``m`` the ranks of all evaluated
    snapshots are concatenated into one list, and HITS@k / MRR are computed on
    that pooled list.  Two passes are made per config: one over all ranks and
    one "inductive" pass.

    Args:
        data_name: Dataset name; selects the number of snapshots.
        data_path: Directory holding ``g<d>/{train,valid,test}.txt`` and
            ``stat.txt`` (only read in the inductive pass).
        configs: Iterable of dicts providing ``dataset``, ``train_mode``,
            ``ewc_lambda``, ``buffer_size``, ``checkpoint_dir`` (a
            %-template) and ``legend``.

    Side effects:
        Prints the accumulated table rows and shows a 2x2 figure (row 0: all
        ranks, row 1: inductive; column 0: HITS@10, column 1: MRR).
    """
    inductive_str = ''
    all_str = ''

    def _hits(ranks, k):
        """Fraction of ranks that are at most ``k``."""
        return np.mean([rank <= k for rank in ranks])

    def _mrr(ranks):
        """Mean reciprocal rank."""
        return np.mean([1.0 / rank for rank in ranks])

    def _cells(ranks):
        """Format HITS@10, HITS@3, HITS@1 and MRR of one pooled rank list."""
        values = (_hits(ranks, 10), _hits(ranks, 3), _hits(ranks, 1), _mrr(ranks))
        return ''.join('%.3f &' % v for v in values)

    # Two rows and two columns, so that the "all" and "inductive" passes get
    # their own axes instead of being drawn over each other under one title.
    fig, axs = plt.subplots(2, 2, figsize=(10, 10))

    start = 1
    if data_name == 'GDELT-S':
        end = 15
    elif data_name == 'ICEWS14':
        end = 33
    elif data_name == 'JCPenny':
        end = 9
    else:
        end = 15
    rank_temp_fn = 'ranks_model_%d_data_%d_inductive_False.npy'
    stat_path = os.path.join(data_path, 'stat.txt')

    for conf in configs:
        train_mode = conf['train_mode']
        ewc_lambda = conf['ewc_lambda']
        buffer_size = conf['buffer_size']
        check_point_dir = conf['checkpoint_dir'] % (conf['dataset'], train_mode, ewc_lambda, buffer_size)
        rankd_dir = check_point_dir + 'ranks/'

        for j, inductive in enumerate([False, True]):
            # results_union[m - start]: ranks of model m pooled over snapshots.
            results_union = [[] for _ in range(end - start)]
            baseData = None
            for m in range(start, end):
                for d in range(start, end):
                    if m != start and m < d:
                        continue
                    rank_fn = rank_temp_fn % (m, d)
                    try:
                        ranks = np.load(rankd_dir + rank_fn, allow_pickle=True)
                    except Exception as e:
                        # Best effort: report the unreadable file and go on.
                        print(e)
                        continue

                    # NOTE(review): only the last model is restricted to the
                    # inductive subset; the first model's pooled ranks
                    # (results_union[0]) stay unfiltered even though they are
                    # reported in the inductive row -- confirm this is intended.
                    if inductive and m == end - 1:
                        snapshot = 'g%d' % d
                        testpath = os.path.join(data_path, '%s/test.txt' % snapshot)
                        trainpath = os.path.join(data_path, '%s/train.txt' % snapshot)
                        validpath = os.path.join(data_path, '%s/valid.txt' % snapshot)
                        if not baseData:
                            baseData = baseDataset(trainpath, testpath, stat_path, validpath, args=None)
                        else:
                            baseData.update(trainpath, testpath, validpath)

                        test_dataset = QuadruplesDataset(baseData.testQuadruples, baseData.num_r, None, False)
                        if len(test_dataset) != len(ranks):
                            # The ranks cannot be aligned with the dataset; they
                            # are pooled unfiltered in this case.
                            print('length mismatch', len(test_dataset), len(ranks))
                        else:
                            indices = []
                            for idx, q in enumerate(test_dataset):
                                s, _, o, _, _ = q
                                if s not in baseData.seen_entities or o not in baseData.seen_entities:
                                    indices.append(idx)
                            ranks = ranks[indices]

                    results_union[m - start].extend(ranks)

            if inductive:
                print(len(results_union[-1]))
                inductive_str += _cells(results_union[0])
                inductive_str += _cells(results_union[-1])
                inductive_str += '\n'
                print(inductive_str)
            else:
                all_str += '%s &' % conf['legend']
                all_str += _cells(results_union[-1])
                all_str += '\n'
                print(all_str)

            for i, k in enumerate(['HITS@10', 'MRR']):
                if k == 'HITS@10':
                    y = [_hits(model_ranks, 10) for model_ranks in results_union]
                else:
                    y = [_mrr(model_ranks) for model_ranks in results_union]

                x = list(range(len(y)))
                axs[j, i].plot(x, y, label='%s' % conf['legend'])
                axs[j, i].set_title(k + (' inductive' if inductive else ' all'))

    # Every panel carries the same labels, so one legend (top right) suffices.
    axs[0, 1].legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.tight_layout()
    plt.show()
    print(all_str)
    print(inductive_str)



def plot_print_results(data_name, data_path, configs):
    """Print LaTeX-style table rows and save running metric curves to a PDF.

    For every config the files
    ``results__model_<m>_data_<d>_inductive_<flag>.npy`` are loaded from the
    config's ``ranks/`` directory.  Each file is unwrapped with ``.item()``
    and its key/value pairs are written into ``(model, snapshot)`` matrices
    (presumably a pickled dict of metric name to value -- the keys must be
    among MRR, HITS@1, HITS@3, HITS@10).  Unreadable files are reported and
    leave zeros in the matrices.

    Args:
        data_name: Dataset name; selects the number of snapshots.
        data_path: Not used by this function; kept for a signature parallel
            to the other plotting helpers in this file.
        configs: Iterable of dicts providing ``dataset``, ``train_mode``,
            ``ewc_lambda``, ``buffer_size``, ``checkpoint_dir`` (a
            %-template) and ``legend``.

    Side effects:
        Writes ``exp1.pdf`` to the working directory and prints the table
        rows.
    """
    inductive_str = ''
    all_str = ''
    # Column order of a printed row: last-snapshot values, then row means.
    metric_order = ['HITS@10', 'HITS@3', 'HITS@1', 'MRR']

    def _row(label, results, row):
        """Format one table row from matrix row ``row``."""
        last = ['%.3f &' % results[k][row, -1] for k in metric_order]
        mean = ['%.3f &' % np.mean(results[k][row, :]) for k in metric_order]
        return '%s &' % label + ''.join(last + mean)

    # One row, two columns: axs[0] shows HITS@10 and axs[1] shows MRR.
    fig, axs = plt.subplots(1, 2, figsize=(15, 8))

    start = 1
    if data_name == 'GDELT-S':
        end = 15
    elif data_name == 'ICEWS14':
        end = 33
    elif data_name == 'JCPenny':
        end = 9
    else:
        end = 15
    result_temp_fn = 'results__model_%d_data_%d_inductive_%r.npy'

    for conf in configs:
        train_mode = conf['train_mode']
        ewc_lambda = conf['ewc_lambda']
        buffer_size = conf['buffer_size']
        check_point_dir = conf['checkpoint_dir'] % (conf['dataset'], train_mode, ewc_lambda, buffer_size)
        rankd_dir = check_point_dir + 'ranks/'

        # Only the non-inductive results are processed at the moment; add
        # True to this list to process the inductive files as well.
        for inductive in [False]:
            results = {k: np.zeros((end - start, end - start))
                       for k in ['MRR', 'HITS@1', 'HITS@3', 'HITS@10']}
            for m in range(start, end):
                for d in range(start, end):
                    if m != start and m < d:
                        continue

                    result_fn = result_temp_fn % (m, d, inductive)
                    try:
                        result = np.load(rankd_dir + result_fn, allow_pickle=True)
                    except Exception as e:
                        # Best effort: report the unreadable file and go on.
                        print(e)
                        continue
                    for k, v in result.item().items():
                        results[k][m - start, d - start] = v

            if inductive:
                inductive_str += _row(conf['legend'], results, -1) + '\n'
            else:
                if conf['legend'] == 'FT':
                    # Extra row built from the first model (matrix row 0).
                    all_str += _row('naive', results, 0) + '\n'
                all_str += _row(conf['legend'], results, -1) + '\n'

            for i, k in enumerate(['HITS@10', 'MRR']):
                # y: model r averaged over snapshots 1..r.
                # y2: the first model averaged over the same snapshots.
                y = [results[k][ind, :ind + 1].mean() for ind in range(end - 1)]
                y2 = [results[k][0, :ind + 1].mean() for ind in range(end - 1)]

                x = list(range(len(y)))
                axs[i].plot(x, y, label='%s' % conf['legend'])
                # NOTE(review): the table above keys the first-model row on
                # the legend 'FT' while this curve is keyed on 'naive' (and
                # would duplicate that label) -- confirm which is intended.
                if conf['legend'] == 'naive':
                    axs[i].plot(x, y2, label='naive')
                axs[i].set_title(k)

    axs[1].legend(loc='best')
    fig.text(0.5, 0.01, 'Training Step', ha='center', va='center')
    # tight_layout() recomputes the subplot margins, so a subplots_adjust()
    # issued before it has no effect.  Reserve a strip at the bottom for the
    # shared x label through `rect` instead.
    plt.tight_layout(rect=(0, 0.03, 1, 1))
    plt.savefig('exp1.pdf', dpi=300)

    print(inductive_str)
    print(all_str)


def analyse_long_tail_performance(data_path, data_name, configs=None):
    """Plot the average hit@1 of the last model per entity-frequency bin.

    For every config the rank files of the last model
    (``ranks_model_<end-1>_data_<d>_inductive_False.npy``) are loaded for each
    snapshot ``d``.  The first ``n = len(testQuadruples)`` ranks are paired
    with the test quadruples using field 0 as the source entity; the ranks
    after position ``n`` are paired with the same quadruples using field 2 as
    the source (presumably the inverse-direction queries -- confirm against
    the evaluation code).  Records are binned by source frequency on a
    logarithmic grid and one bar group per config is drawn.

    Args:
        data_path: Directory holding ``g<d>/{train,valid,test}.txt`` and
            ``stat.txt``.
        data_name: Dataset name; selects the number of snapshots.
        configs: Iterable of config dicts (``dataset``, ``train_mode``,
            ``ewc_lambda``, ``buffer_size``, ``checkpoint_dir``, ``legend``).
            When omitted, the module-level ``configs`` list is used, which is
            what the function relied on before this parameter existed.

    Side effects:
        Writes ``exp3_hit@1.pdf`` to the working directory.
    """
    print('ANALYSE LONG TAIL')
    if configs is None:
        # Backward compatibility: fall back to the list that the __main__
        # block creates at module level.
        configs = globals()['configs']

    bar_width = 0.1  # You can adjust this as needed
    num_configs = len(configs)
    metric = 'hit@1'

    start = 1
    if data_name == 'GDELT-S':
        end = 15
    elif data_name == 'ICEWS14':
        end = 33
    elif data_name == 'JCPenny':
        end = 9
    else:
        end = 14
    m = end - 1  # only the last model is analysed
    rank_temp_fn = 'ranks_model_%d_data_%d_inductive_%r.npy'
    stat_path = os.path.join(data_path, 'stat.txt')

    def _record(src_freq, dst_freq, rank):
        """Build one result row for a (source, destination, rank) triple."""
        return (src_freq, dst_freq, np.mean([src_freq, dst_freq]),
                int(rank <= 1), int(rank <= 3), int(rank <= 10),
                rank, 1.0 / rank, 1)

    for config_idx, conf in enumerate(configs):
        train_mode = conf['train_mode']
        ewc_lambda = conf['ewc_lambda']
        buffer_size = conf['buffer_size']
        check_point_dir = conf['checkpoint_dir'] % (conf['dataset'], train_mode, ewc_lambda, buffer_size)
        rank_dir = check_point_dir + 'ranks/'

        baseData = None
        results = []
        for d in range(start, end):
            snapshot = 'g%d' % d
            testpath = os.path.join(data_path, '%s/test.txt' % snapshot)
            trainpath = os.path.join(data_path, '%s/train.txt' % snapshot)
            validpath = os.path.join(data_path, '%s/valid.txt' % snapshot)
            if not baseData:
                baseData = baseDataset(trainpath, testpath, stat_path, validpath, args=None)
            else:
                baseData.update(trainpath, testpath, validpath)

            # Entity frequencies as known after loading snapshot d.
            freq = baseData.freq
            result_fn = rank_temp_fn % (m, d, False)
            ranks = np.load(rank_dir + result_fn, allow_pickle=True)

            n = len(baseData.testQuadruples)
            for q, r in zip(baseData.testQuadruples, ranks):
                results.append(_record(freq[q[0]], freq[q[2]], r))
            for q, r in zip(baseData.testQuadruples, ranks[n:]):
                results.append(_record(freq[q[2]], freq[q[0]], r))

        max_freq = max(freq.values())
        bins = np.logspace(np.log10(1), np.log10(max_freq), num=10)
        # The bins are left-closed (right=False), so a top edge equal to the
        # maximum frequency would drop the most frequent entity.  Nudge the
        # edge just above it; the integer label stays the same.
        bins[-1] = np.nextafter(float(max_freq), np.inf)

        # Each bin is labelled with its (truncated) upper edge.
        labels = ['%d' % (int(bins[i + 1])) for i in range(len(bins) - 1)]
        indices = np.arange(len(labels))
        print(bins, labels)

        df = pd.DataFrame(results, columns=['src_freq', 'dst_freq', 'mean_freq', 'hit@1', 'hit@3', 'hit@10', 'rank', 'mrr', 'count'])
        df['freq_bin'] = pd.cut(df['src_freq'], bins=bins, labels=labels, right=False)
        # observed=False keeps empty bins as rows, so `grouped` always lines
        # up with `indices` regardless of the pandas default.
        grouped = df.groupby('freq_bin', observed=False).mean().reset_index()

        legend = ' '.join(conf['legend'].split('_')[-4:])
        plt.bar(indices + config_idx * bar_width, grouped[metric], width=bar_width,
                label='%s' % legend)

    plt.gca().get_figure().set_size_inches(10, 5)
    plt.xlabel('Source Frequency Bins')
    plt.ylabel('Average %s' % metric.upper())
    # Position x-axis labels in the center of the grouped bars.
    plt.xticks(indices + bar_width * (num_configs - 1) / 2, labels)
    plt.legend(loc='upper left', bbox_to_anchor=(1, 1))
    plt.grid(axis='y')
    plt.tight_layout()

    plt.savefig('exp3_%s.pdf' % metric, dpi=300)

def generate_latex_table(resultn_dict):
    """Render per-config frequency-bin metrics as a LaTeX ``table``.

    Args:
        resultn_dict: Mapping from a config name to a DataFrame that has a
            ``freq_bin`` column (labels ``'0-15'``, ``'15-1000'``,
            ``'1000-2178'``) and one column per reported metric
            (``'hit@1'`` and ``'hit@10'``).

    Returns:
        The LaTeX source as a string.  Each config becomes one row labelled
        with the last four ``_``-separated parts of its name; a missing or
        NaN value is rendered as ``-``.
    """
    metrics = ['hit@1', 'hit@10']
    freq_bins = ['0-15', '15-1000', '1000-2178']
    bin_headers = ['0-15', '15-1000', '1000+']
    # One label column plus one column per (metric, bin) pair.
    n_cols = 1 + len(metrics) * len(freq_bins)

    metric_header = " & ".join(
        "\\multicolumn{%d}{c|}{%s}" % (len(freq_bins), metric) for metric in metrics
    )
    lines = [
        "\\begin{table}[h]",
        "\\centering",
        "\\caption{Your caption}",
        "\\label{your-label}",
        # The column spec and \cline range are derived from the actual
        # number of cells so header and rows cannot drift apart.
        "\\begin{tabular}{|" + "c|" * n_cols + "}",
        "\\hline",
        "\\multirow{2}{*}{Config} & " + metric_header + " \\\\ \\cline{2-%d}" % n_cols,
        "                        & " + " & ".join(bin_headers * len(metrics)) + "\\\\ \\hline",
    ]

    for config, df in resultn_dict.items():
        legend = ' '.join(config.split('_')[-4:])
        # Plain braces: \texttt takes the legend as its argument.
        row = "$\\texttt{%s}$ " % legend
        for metric in metrics:
            for freq_bin in freq_bins:
                value = df.loc[df['freq_bin'] == freq_bin, metric].values
                if len(value) > 0 and not pd.isna(value[0]):
                    row += f"& {value[0]:.2f} "
                else:
                    row += "& - "
        lines.append(row + "\\\\ \\hline")

    lines.append("\\end{tabular}")
    lines.append("\\end{table}")
    return "\n".join(lines) + "\n"

# Assuming resultn_dict is filled with Pandas DataFrames


def create_long_tail_table(data_path, data_name, config):
    """Write LaTeX tables of hit rates per source-frequency bin.

    Three groups of (test quadruple, rank) records are collected per config:

    * ``result11`` -- model 1 on snapshot 1,
    * ``resultnn`` -- the last model on the last snapshot,
    * ``resultn``  -- the last model on every earlier snapshot.

    Each group is binned by source-entity frequency, averaged per bin and
    rendered with ``generate_latex_table``.

    Args:
        data_path: Directory holding ``g<d>/{train,valid,test}.txt`` and
            ``stat.txt``.
        data_name: Dataset name; selects the number of snapshots and is
            substituted into the ``checkpoint_dir`` template.
        config: Iterable of config dicts (``train_mode``, ``ewc_lambda``,
            ``buffer_size``, ``checkpoint_dir``, ``legend``).

    Side effects:
        Writes ``result11_latex_table.tex``, ``resultn_latex_table.tex`` and
        ``resultnn_latex_table.tex`` to the working directory.
    """
    # Use the argument that was passed in, not a module-level global.
    configs = config

    result11_dict = {}
    resultnn_dict = {}
    resultn_dict = {}

    start = 1
    if data_name == 'GDELT-S':
        end = 15
    elif data_name == 'ICEWS14':
        end = 33
    elif data_name == 'JCPenny':
        end = 9
    else:
        end = 14
    rank_temp_fn = 'ranks_model_%d_data_%d_inductive_%r.npy'
    stat_path = os.path.join(data_path, 'stat.txt')

    # Fixed frequency bins; the labels must match the ones that
    # generate_latex_table looks up ('0-15', '15-1000', '1000-2178').
    bins = [0, 15, 1000, 2178]
    labels = ['%d-%d' % (int(bins[i]), int(bins[i + 1])) for i in range(len(bins) - 1)]

    for conf in configs:
        train_mode = conf['train_mode']
        ewc_lambda = conf['ewc_lambda']
        buffer_size = conf['buffer_size']
        check_point_dir = conf['checkpoint_dir'] % (data_name, train_mode, ewc_lambda, buffer_size)
        rank_dir = check_point_dir + 'ranks/'

        baseData = None
        result11 = []
        resultnn = []
        resultn = []
        for d in range(start, end):
            snapshot = 'g%d' % d
            testpath = os.path.join(data_path, '%s/test.txt' % snapshot)
            trainpath = os.path.join(data_path, '%s/train.txt' % snapshot)
            validpath = os.path.join(data_path, '%s/valid.txt' % snapshot)
            if not baseData:
                baseData = baseDataset(trainpath, testpath, stat_path, validpath, args=None)
            else:
                baseData.update(trainpath, testpath, validpath)

            # Entity frequencies as known after loading snapshot d.
            freq = baseData.freq
            for m in range(d, end):
                if m == 1 and d == 1:
                    target = result11
                elif m == end - 1 and d == end - 1:
                    target = resultnn
                elif m == end - 1:
                    target = resultn
                else:
                    # No table uses this (model, snapshot) pair, so its rank
                    # file does not need to be read.
                    continue

                result_fn = rank_temp_fn % (m, d, False)
                ranks = np.load(rank_dir + result_fn, allow_pickle=True)
                target.extend(
                    (freq[q[0]], freq[q[2]], int(r <= 1), int(r <= 3), int(r <= 10), r, 1)
                    for q, r in zip(baseData.testQuadruples, ranks)
                )

        # Each record list is stored in the dict of the same name.
        for result, result_dict in ((result11, result11_dict),
                                    (resultnn, resultnn_dict),
                                    (resultn, resultn_dict)):
            df = pd.DataFrame(result,
                              columns=['src_freq', 'dst_freq', 'hit@1', 'hit@3', 'hit@10', 'rank', 'count'])
            df['freq_bin'] = pd.cut(df['src_freq'], bins=bins, labels=labels, right=False)
            # observed=False keeps empty bins as (NaN) rows regardless of the
            # pandas default.
            grouped = df.groupby('freq_bin', observed=False).mean().reset_index()

            result_dict[conf['legend']] = grouped[['freq_bin', 'hit@1', 'hit@3', 'hit@10']]

    for out_fn, result_dict in (("result11_latex_table.tex", result11_dict),
                                ("resultn_latex_table.tex", resultn_dict),
                                ("resultnn_latex_table.tex", resultnn_dict)):
        latex = generate_latex_table(result_dict)
        with open(out_fn, "w") as f:
            f.write(latex)








if __name__ == '__main__':
    # Run configurations, filled from the JSON file below.
    configs = []

    with open('../scripts/eval_config_paper_ICEWS18.json', 'r') as f:
        test_configs = json.load(f)

    for conf in test_configs:
        # Template for the checkpoint directory of this run.
        # NOTE(review): unless conf['prefix'] itself contains a placeholder,
        # this template has three placeholders while the functions in this
        # file format it with four values -- confirm the intended layout.
        conf['checkpoint_dir'] = ''.join([
            'path_to_checkpoint_dir',
            'checkpoints_%s' % conf['prefix'],
            '%s_%s_%d/',
        ])

        # Fill in the optional keys without touching values that are present.
        conf.setdefault('legend', conf['prefix'])
        conf.setdefault('buffer_size', 0)
        conf.setdefault('ewc_lambda', 0.0)
        configs.append(conf)

    # create_long_tail_table(data_path='data/ICEWS14', data_name='ICEWS14', config=configs)
    # plot_print_results_union(data_name='ICEWS14', configs=configs, data_path='data/ICEWS14')
    # plot_print_results_union(data_name='ICEWS1807', configs=configs, data_path='data/ICEWS1807')
    plot_print_results(data_name='ICEWS1807', configs=configs, data_path='data/ICEWS1807')
    # analyse_long_tail_performance(data_path='data/ICEWS14', data_name='ICEWS14')
    # analyse_long_tail_performance(data_path='data/ICEWS14', data_name='ICEWS14')

