from collections import defaultdict
import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
# Global text style for every figure produced by this script: serif font family
# and LaTeX-rendered text. `matplotlib.rc('text', usetex=True)` already sets
# rcParams['text.usetex'], so it is not assigned a second time.
matplotlib.rc('font', family='serif')
matplotlib.rc('text', usetex=True)

FULL_PLOT_WIDTH = 5.52  # Not used; kept for reference. TODO: check against the paper template.
WIDTH_CAPTION_GAP = 0.4  # Width ratio of the left caption column (holds the shared y label).
HEIGHT_CAPTION_GAP = 0.1  # Height ratio of the bottom caption row (holds the shared x label).

# Width ratio of the spacer columns / height ratio of the spacer row between
# panels; every panel column/row has ratio 1.
WIDTH_GAP, HEIGHT_GAP = 0.38, 0.25
# Passed as `figsize` for the whole mosaic figure (not for a single panel).
SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT = 8, 4
# Base font size: tick labels use it directly; titles, axis labels and the
# legend use small offsets from it.
LABEL_FONT = 12
# Qualitative colormap; model index j in MODEL_NAMES is drawn with cmap(j).
cmap = plt.get_cmap('tab10')

# PREV_DATASETS = ['amazon', 'retweet', 'taxi', 'taobao', 'stackoverflow', 'ehrshot', 'nlb1rep', 'lastfm', 'mimic_jitter']

# Display name of every plotted dataset, in panel order (panels A..H of the
# 2x4 grid). Commented-out entries are datasets that are no longer shown.
DATASET_LABELS = {
    'amazon': 'Amazon',
    # 'retweet': 'Retweet',
    'retweet_jitter': 'Retweet',
    'taxi': 'Taxi',
    'taobao': 'Taobao',
    'stackoverflow': 'StackOverflow',
    'lastfm': 'Last.fm',
    # 'mimic': 'MIMIC-II',
    'mimic_jitter': 'MIMIC-II',
    'ehrshot': 'EHRSHOT',
    # 'nlb1rep': 'N.S.',
}

# Dataset order of the positional result lists in the __main__ block.
# The plotting functions iterate DATASET_LABELS, while the result lists are
# mapped to datasets through CURR_DATASETS. Deriving one from the other keeps
# them from silently drifting apart.
CURR_DATASETS = list(DATASET_LABELS)

# MARKERS = ['X', 'P', 'd', 'H', '*', '^', 's', '.']
# Models in legend order; a model's index also selects its colour, cmap(index).
MODEL_NAMES = ['RMTPP', 'NHP', 'SAHP', 'THP', 'AttNHP', 'IFTPP', 'DLHP']


def my_formatter(x, pos, ndigits=10):
    """Format a tick value for `FuncFormatter`, printing zero as a bare ``'0'``.

    Tick locations are computed as ``k * step + offset`` in floating point, so a
    value such as ``3 * 0.05`` arrives as ``0.15000000000000002``. The value is
    rounded to ``ndigits`` decimals before formatting to drop that noise. The
    zero test runs after rounding, so ``-0.0`` and near-zero noise also print
    as ``'0'``.

    Args:
        x: tick value.
        pos: tick position index (required by the FuncFormatter signature; unused).
        ndigits: number of decimals kept before formatting.

    Returns:
        The tick label string.
    """
    value = round(x, ndigits)
    return '0' if value == 0 else str(value)


def adjust(fig, left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.0):
    """Forward subplot margins and inter-subplot spacing to ``fig.subplots_adjust``.

    ``left``/``right``/``bottom``/``top`` place the edges of the subplot area;
    ``wspace``/``hspace`` reserve blank width/height between subplots. All six
    values are passed through unchanged as keyword arguments. The defaults make
    the subplots fill the figure with no gaps.
    """
    layout = {
        'left': left,
        'right': right,
        'bottom': bottom,
        'top': top,
        'wspace': wspace,
        'hspace': hspace,
    }
    fig.subplots_adjust(**layout)

def disable_axis(ax):
    """Turn ``ax`` into an invisible placeholder: no ticks, tick labels or frame.

    The plotting functions in this file still call ``set_xlabel``/``set_ylabel``
    on such axes to draw the shared, figure-level axis captions.
    """
    ax.set_zorder(-100)  # Avoids a visual rendering bug
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks([])
    ax.set_yticklabels([])
    # Hide the frame explicitly. Setting the spine colour to None only made it
    # invisible because None resolved to 'none' under the default rcParams;
    # that resolution depends on settings such as `patch.force_edgecolor`.
    for spine in ax.spines.values():
        spine.set_visible(False)


def plot_time_ll_pce(time_ll, time_pce, save_fig=False, plot_isobar=False):
    """Plot time log-likelihood against time PCE, one panel per dataset.

    Builds a 2x4 mosaic of panels A..H (one per key of ``DATASET_LABELS``, in
    order) and draws one 'X' marker per model in ``MODEL_NAMES``. Each model is
    coloured ``cmap(j)`` by its index, so colours match across panels. Disabled
    axes in the left column (Y1) and bottom row (X1..X5) carry the shared axis
    labels.

    Args:
        time_ll: nested mapping ``time_ll[model][dataset]`` giving the y value
            (time log-likelihood). A ``None`` value or a missing dataset key
            marks an unavailable result; that point is skipped.
        time_pce: nested mapping ``time_pce[model][dataset]`` giving the x value
            (time PCE), with the same layout and missing-value convention.
        save_fig: if True, save the figure to
            ``../plot/time_ll_pce_topleft_jitter_final.pdf`` before showing it.
        plot_isobar: if True, draw a dashed grey line through the lower-left and
            upper-right corners of each panel's current axis limits.
    """
    fig, axd = plt.subplot_mosaic(
        mosaic=[
            ['Y1', 'A', '.', 'B', '.', 'C', '.', 'D'],
            ['Y1', '.', '.', '.', '.', '.', '.', '.'],
            ['Y1', 'E', '.', 'F', '.', 'G', '.', 'H'],
            ['Y1', 'X1', '.', 'X2', 'X5', 'X3', '.', 'X4'],
        ],
        gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1],
                     "height_ratios": [1, HEIGHT_GAP, 1, HEIGHT_CAPTION_GAP]},
        figsize=(SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT),
    )

    # Caption-only cells. X5 sits in the middle spacer column of the bottom row,
    # so its x label is centred under the whole grid.
    for x_idx in ['X1', 'X2', 'X3', 'X4', 'X5']:
        disable_axis(axd[x_idx])
    axd['X5'].set_xlabel("PCE", fontsize=LABEL_FONT+2, labelpad=10)

    disable_axis(axd['Y1'])
    axd['Y1'].set_ylabel("Time Log-Likelihood", fontsize=LABEL_FONT+2, labelpad=10, y=0.55)

    subplot_ids = list('ABCDEFGH')
    for i, dataset in enumerate(DATASET_LABELS.keys()):
        print(f'Current dataset: {dataset}')
        subplot_axd = axd[subplot_ids[i]]

        for j, model in enumerate(MODEL_NAMES):
            pce = time_pce[model].get(dataset)
            ll = time_ll[model].get(dataset)
            # Skip unavailable results (e.g. AttNHP has no EHRSHOT entry)
            # instead of special-casing particular (dataset, model) pairs.
            if pce is None or ll is None:
                continue
            subplot_axd.plot(pce, ll, marker='X', color=cmap(j), label=model, alpha=0.6)

        subplot_axd.set_title(DATASET_LABELS[dataset], fontsize=LABEL_FONT - 1)

        # subplot_axd.set_xlim((0, 0.5))
        subplot_axd.set_xlim((0, None))
        if plot_isobar:
            x_lower, x_upper = subplot_axd.get_xlim()
            y_lower, y_upper = subplot_axd.get_ylim()
            subplot_axd.axline([x_lower, y_lower], [x_upper, y_upper], color='grey', ls='--', linewidth=1,
                               alpha=0.5)

        # subplot_axd.invert_xaxis()
        subplot_axd.xaxis.set_tick_params(labelsize=LABEL_FONT)
        subplot_axd.yaxis.set_tick_params(labelsize=LABEL_FONT)
        subplot_axd.xaxis.set_major_formatter(FuncFormatter(my_formatter))

        # A single shared legend, anchored above the first (top-left) panel.
        if dataset == 'amazon':
            subplot_axd.legend(loc=[-0.1, 1.28], ncol=7, fontsize=LABEL_FONT - 4)

    if save_fig:
        # fig.savefig('../plot/time_ll_pce_topleft_jitter.pdf', bbox_inches='tight')
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig('../plot/time_ll_pce_topleft_jitter_final.pdf', bbox_inches='tight')
    plt.show()



# def plot_time_ll_pce_horizontal(time_ll, time_pce, save_fig=False):
#     fig, axd = plt.subplot_mosaic(
#         mosaic=[
#             ['Y1', 'A', '.', 'B', '.', 'C', '.', 'D', '.', 'E', '.', 'F', '.', 'G'],
#             ['Y1', 'X1', '.', 'X2', '.', 'X3', '.', 'X4', '.', 'X5', '.', 'X6', '.', 'X7'],
#         ],
#         gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1,
#                                       WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1],
#                      "height_ratios": [1, HEIGHT_CAPTION_GAP]},
#         figsize=(SUB_PLOT_WIDTH * 2, SUB_PLOT_HEIGHT / 2),
#     )
#
#     for x_idx in ['X1', 'X2', 'X3', 'X4', 'X5', 'X6', 'X7']:
#         disable_axis(axd[x_idx])
#         if x_idx == "X4":
#             axd[x_idx].set_xlabel("PCE", fontsize=LABEL_FONT+2, labelpad=10)
#
#
#     disable_axis(axd['Y1']); axd['Y1'].set_ylabel(r"$-\mathcal{L}_{Time}$", fontsize=LABEL_FONT+2, labelpad=10, y=0.6)
#
#     subplot_ids = list('ABCDEFG')
#     for i, dataset in enumerate(DATASET_LABELS.keys()):
#         print(f'Current dataset: {dataset}')
#         subplot_axd = axd[subplot_ids[i]]
#
#         for j, model in enumerate(MODEL_NAMES):
#             if dataset == 'ehrshot' and model == 'AttNHP':
#                 continue
#             subplot_axd.plot(time_pce[model][dataset], time_ll[model][dataset], marker='X', color=cmap(j), label=model)
#         subplot_axd.set_title(DATASET_LABELS[dataset], fontsize=LABEL_FONT - 1)
#
#         # subplot_axd.set_xlim((0, 0.5))
#         subplot_axd.set_xlim((0, None))
#         subplot_axd.invert_xaxis()
#         subplot_axd.xaxis.set_tick_params(labelsize=LABEL_FONT)
#         subplot_axd.yaxis.set_tick_params(labelsize=LABEL_FONT)
#         subplot_axd.xaxis.set_major_formatter(FuncFormatter(my_formatter))
#
#         if dataset == 'retweet':
#             subplot_axd.legend(loc=[-0.25, 1.28], ncol=7, fontsize=LABEL_FONT - 1)
#     if save_fig:
#         plt.savefig('../plot/time_ll_pce_horizontal.pdf', bbox_inches='tight')
#     plt.show()


def plot_mark_ll_ece(mark_ll, mark_pce, save_fig=False, plot_isobar=False):
    """Plot mark log-likelihood against mark ECE, one panel per dataset.

    Builds a 2x4 mosaic of panels A..H (one per key of ``DATASET_LABELS``, in
    order) and draws one 'X' marker per model in ``MODEL_NAMES``. Each model is
    coloured ``cmap(j)`` by its index, so colours match across panels. Disabled
    axes in the left column (Y1) and bottom row (X1..X5) carry the shared axis
    labels.

    Args:
        mark_ll: nested mapping ``mark_ll[model][dataset]`` giving the y value
            (mark log-likelihood). A ``None`` value or a missing dataset key
            marks an unavailable result; that point is skipped.
        mark_pce: nested mapping ``mark_pce[model][dataset]`` giving the x value.
            Despite the parameter name (kept for existing callers), these are
            the mark ECE values that label the x axis.
        save_fig: if True, save the figure to
            ``../plot/mark_ll_ece_topleft_jitter_final.pdf`` before showing it.
        plot_isobar: if True, draw a dashed grey line through the lower-left and
            upper-right corners of each panel's current axis limits.
    """
    fig, axd = plt.subplot_mosaic(
        mosaic=[
            ['Y1', 'A', '.', 'B', '.', 'C', '.', 'D'],
            ['Y1', '.', '.', '.', '.', '.', '.', '.'],
            ['Y1', 'E', '.', 'F', '.', 'G', '.', 'H'],
            ['Y1', 'X1', '.', 'X2', 'X5', 'X3', '.', 'X4'],
        ],
        gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1],
                     "height_ratios": [1, HEIGHT_GAP, 1, HEIGHT_CAPTION_GAP]},
        figsize=(SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT),
    )

    # Caption-only cells. X5 sits in the middle spacer column of the bottom row,
    # so its x label is centred under the whole grid.
    for x_idx in ['X1', 'X2', 'X3', 'X4', 'X5']:
        disable_axis(axd[x_idx])
    axd['X5'].set_xlabel("ECE", fontsize=LABEL_FONT+2, labelpad=10)

    disable_axis(axd['Y1'])
    axd['Y1'].set_ylabel("Mark Log-Likelihood", fontsize=LABEL_FONT+2, labelpad=10, y=0.55)

    subplot_ids = list('ABCDEFGH')
    for i, dataset in enumerate(DATASET_LABELS.keys()):
        print(f'Current dataset: {dataset}')
        subplot_axd = axd[subplot_ids[i]]

        for j, model in enumerate(MODEL_NAMES):
            ece = mark_pce[model].get(dataset)
            ll = mark_ll[model].get(dataset)
            # Skip unavailable results (e.g. AttNHP has no EHRSHOT entry)
            # instead of special-casing particular (dataset, model) pairs.
            if ece is None or ll is None:
                continue
            subplot_axd.plot(ece, ll, marker='X', color=cmap(j), label=model, alpha=0.6)

        subplot_axd.set_title(DATASET_LABELS[dataset], fontsize=LABEL_FONT - 1)

        # subplot_axd.set_xlim((0, 0.5))
        subplot_axd.set_xlim((0, None))
        if plot_isobar:
            x_lower, x_upper = subplot_axd.get_xlim()
            y_lower, y_upper = subplot_axd.get_ylim()
            subplot_axd.axline([x_lower, y_lower], [x_upper, y_upper], color='grey', ls='--', linewidth=1,
                               alpha=0.5)

        # subplot_axd.invert_xaxis()
        subplot_axd.xaxis.set_tick_params(labelsize=LABEL_FONT)
        subplot_axd.yaxis.set_tick_params(labelsize=LABEL_FONT)
        subplot_axd.xaxis.set_major_formatter(FuncFormatter(my_formatter))

        # A single shared legend, anchored above the first (top-left) panel.
        if dataset == 'amazon':
            subplot_axd.legend(loc=[-0.1, 1.28], ncol=7, fontsize=LABEL_FONT - 4)

    if save_fig:
        # fig.savefig('../plot/mark_ll_ece_topleft_jitter.pdf', bbox_inches='tight')
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig('../plot/mark_ll_ece_topleft_jitter_final.pdf', bbox_inches='tight')
    plt.show()










if __name__ == '__main__':
    # Earlier result tables, kept commented out for reference. The first set has
    # 9 values per model (presumably in PREV_DATASETS order); the second set has
    # 8 values in CURR_DATASETS order.
    # time_ll = {
    #     'RMTPP': [0.010 , -6.126 , 0.622 , 2.427 , -0.780 , -1.888 , -1.542 , 0.259, -0.182],
    #     'NHP': [2.196 , -5.545 , 0.728 , 2.579 , -0.703 , -0.758 , -1.540 , 1.196, 0.240],
    #     'SAHP': [0.173 , -5.670 , 0.681 , 2.612 , -0.681 , -1.779 , -1.541 , 0.600, -0.298],
    #     'THP': [-0.070 , -5.850 , 0.623 , 2.242 , -0.769 , -1.890 , -1.540 , 0.220, -0.277 ],
    #     'AttNHP': [2.545 , -5.638 , 0.724 , 2.665 , -0.681 , None , -1.537 , 1.213, -0.017],
    #     'IFTPP': [2.482 , -9.437 , 0.736 , 2.730 , -0.660, -2.642 , None , 1.290, 0.536],
    #     'DLHP': [2.638 , -0.504 , 0.738 , 2.742 , -0.636 , 0.723 , -1.538 , 1.294, 1.345]
    # }
    #
    # mark_ll = {
    #     'RMTPP': [-2.148 , -0.841 , -0.275 , -1.421 , -1.623 , -6.147 , -4.299 , -2.035, -0.298],
    #     'NHP': [-1.992 , -0.764 , -0.212 , -1.416 , -1.540 , -3.149 , -4.247 , -1.774, -0.164],
    #     'SAHP': [-2.213 , -0.775 , -0.308 , -1.411 , -1.602 , -5.066 , -4.201 , -2.100, -0.475],
    #     'THP': [-2.028 , -0.785 , -0.249 , -1.451 , -1.563 , -5.294 , -4.208 , -1.936, -0.310],
    #     'AttNHP': [-1.938 , -0.771 , -0.225 , -1.387 , -1.498 , None , -4.172 , -1.771, -0.227],
    #     'IFTPP': [-1.989 , -0.843 , -0.282 , -1.395 ,  -1.565 ,   -3.782 , None ,-1.763, -0.237],
    #     'DLHP': [-1.873 , -0.807 , -0.209 , -1.410 , -1.529 , -2.912 , -4.169 , -1.790, -0.114]
    # }

    # time_pce = {
    #     'RMTPP': [0.1370 , 0.0861 , 0.0355 , 0.1018 , 0.0191 , 0.1327 , 0.0020 , 0.1155, 0.0385],
    #     'NHP': [0.0757 , 0.0015 , 0.0027 , 0.0738 , 0.0177 , 0.0824 , 0.0015 , 0.0477, 0.0605],
    #     'SAHP': [0.1086 , 0.0067 , 0.0173 , 0.0288 , 0.0114 , 0.1504 , 0.0009 , 0.1089, 0.0279],
    #     'THP': [0.1228 , 0.0624 , 0.0332 , 0.1632 , 0.0210 , 0.1456 , 0.0027 , 0.1090, 0.0121],
    #     'AttNHP': [0.0620 , 0.0112 , 0.0096 , 0.0317 , 0.0152 , None , 0.0015 , 0.0157, 0.0466],
    #     'IFTPP': [0.0174 , 0.2389 , 0.0044 , 0.0061 , 0.0050 , 0.1766, None, 0.0030, 0.0219],
    #     'DLHP': [0.0347 , 0.3273 , 0.0013 , 0.0205 , 0.0060 , 0.1247 , 0.0030 , 0.0118, 0.0894]
    # }
    #
    # mark_ece = {
    #     'RMTPP': [0.0641 , 0.0225 , 0.0262 , 0.0160 , 0.0136 , 0.0922 , 0.0011 , 0.0244, 0.0197],
    #     'NHP': [0.0675 , 0.0029 , 0.0081 , 0.0440 , 0.0102 , 0.0284 , 0.0012 , 0.0410, 0.0192],
    #     'SAHP': [0.0836 , 0.0033 , 0.0696 , 0.0300 , 0.0112 , 0.1109 , 0.0066 , 0.0855, 0.0577],
    #     'THP': [0.0202 , 0.0129 , 0.0174 , 0.0648 , 0.0077 , 0.1142 , 0.0054 , 0.0267, 0.0181],
    #     'AttNHP': [0.0288 , 0.0036 , 0.0044 , 0.0252 , 0.0121 , None , 0.0057 , 0.0050, 0.0279],
    #     # 'IFTPP': [0.0048 , 0.0019 , 0.0056 , 0.0075 , 0.0111 , 0.0062 , 0.0043 , 0.0066],
    #     'IFTPP': [0.0037 , 0.0065 , 0.0041 , 0.0149 , 0.0148  , 0.0201, None, 0.0059, 0.0140],
    #     'DLHP': [0.0100 , 0.0123 , 0.0046 , 0.0166 , 0.0201 , 0.0119 , 0.0042 , 0.0074, 0.0234]
    # }

    #
    # time_ll = {
    #     'RMTPP': [0.010, -6.231, 0.622, 2.427, -0.780, 0.259, -0.182, -1.888],
    #     'NHP': [2.196, -5.583, 0.728, 2.579, -0.703, 1.196, 0.240, -0.758],
    #     'SAHP': [0.173, -5.895, 0.681, 2.612, -0.681, 0.600, -0.298, -1.779],
    #     'THP': [-0.070, -5.867, 0.623, 2.242, -0.769, 0.220, -0.277, -1.890],
    #     'AttNHP': [2.545, -5.688, 0.724, 2.665, -0.681, 1.213, -0.017, None],
    #     'IFTPP': [2.482, -9.494, 0.736, 2.730, -0.660, 1.290, 0.536, -2.642],
    #     'DLHP': [2.638, -5.600, 0.738, 2.742, -0.636, 1.294, 1.345, 0.723]
    # }
    #
    # mark_ll = {
    #     'RMTPP': [-2.148, -0.939, -0.275, -1.421, -1.623, -2.035, -0.298, -6.147],
    #     'NHP': [-1.992, -0.764, -0.212, -1.416, -1.540, -1.774, -0.164, -3.149],
    #     'SAHP': [-2.213, -0.809, -0.308, -1.411, -1.602, -2.100, -0.475, -5.066],
    #     'THP': [-2.028, -0.786, -0.249, -1.451, -1.563, -1.936, -0.310, -5.294],
    #     'AttNHP': [-1.938, -0.771, -0.225, -1.387, -1.498, -1.771, -0.227, None],
    #     'IFTPP': [-1.989, -0.845, -0.282, -1.395, -1.565, -1.763, -0.237, -3.782],
    #     'DLHP': [-1.873, -0.767, -0.209, -1.410, -1.529, -1.790, -0.114, -2.912]
    # }
    #
    # # PREV_DATASETS = ['amazon', 'retweet', 'taxi', 'taobao', 'stackoverflow', 'ehrshot', 'nlb1rep', 'lastfm', 'mimic_jitter']
    # # CURR_DATASETS = ['amazon', 'retweet_jitter', 'taxi', 'taobao', 'stackoverflow', 'lastfm', 'mimic_jitter', 'ehrshot']
    #
    # time_pce = {
    #     'RMTPP': [0.1370 , 0.0420 , 0.0355 , 0.1018 , 0.0191 , 0.1155, 0.0385, 0.1331],
    #     'NHP': [0.0757 , 0.0015 , 0.0027 , 0.0738 , 0.0177 , 0.0477, 0.0605, 0.0822],
    #     'SAHP': [0.1086 , 0.0975 , 0.0173 , 0.0288 , 0.0114 , 0.1089, 0.0279, 0.1505],
    #     'THP': [0.1228 , 0.0571 , 0.0332 , 0.1632 , 0.0210 ,  0.1090, 0.0121, 0.1455],
    #     'AttNHP': [0.0620 , 0.0126 , 0.0096 , 0.0317 , 0.0152 , 0.0157, 0.0466, None],
    #     'IFTPP': [0.0174 , 0.2393 , 0.0044 , 0.0061 , 0.0050 , 0.0030, 0.0219, 0.1766],
    #     'DLHP': [0.0347 , 0.0040 , 0.0013 , 0.0205 , 0.0060 ,  0.0118, 0.0894, 0.1247]
    # }
    #
    # mark_ece = {
    #     'RMTPP': [0.0641 , 0.0589 , 0.0262 , 0.0160 , 0.0136 , 0.0244 , 0.0197, 0.0922],
    #     'NHP': [0.0675 , 0.0033 , 0.0081 , 0.0440 , 0.0102 , 0.0410 , 0.0192, 0.0284],
    #     'SAHP': [0.0836 , 0.0474 , 0.0696 , 0.0300 , 0.0112 ,  0.0855 , 0.0577, 0.1109],
    #     'THP': [0.0202 , 0.0120 , 0.0174 , 0.0648 , 0.0077 ,   0.0267, 0.0181, 0.1142],
    #     'AttNHP': [0.0288 , 0.0039 , 0.0044 , 0.0252 , 0.0121 , 0.0050, 0.0279, None],
    #     'IFTPP': [0.0037 , 0.0058 , 0.0041 , 0.0149 , 0.0148  , 0.0059, 0.0140, 0.0201],
    #     'DLHP': [0.0100 , 0.0072 , 0.0046 , 0.0166 , 0.0201 ,  0.0074, 0.0234, 0.0119]
    # }

    # Current results: one positional list per model, where entry k belongs to
    # CURR_DATASETS[k]. None marks a missing result (AttNHP on EHRSHOT).
    time_ll = {
        'RMTPP': [0.011, -6.191, 0.622, 2.428, -0.797, 0.256, -0.188, -1.913],
        'NHP': [2.116, -5.584, 0.727, 2.578, -0.699, 1.198, 0.225, -0.821],
        'SAHP': [0.115, -5.872, 0.645, 2.604, -0.703, 0.489, -0.244, -1.801],
        'THP': [-0.068, -5.874, 0.621, 2.242, -0.772, 0.220, -0.271, -1.921],
        'AttNHP': [2.416, -5.726, 0.714, 2.654, -0.684, 1.203, 0.031, None],
        'IFTPP': [2.483, -9.500, 0.735, 2.708, -0.662, 1.277, 0.555, -2.640],
        'DLHP': [2.652, -5.598, 0.733, 2.719, -0.641, 1.257, 1.389, 0.382]
    }

    mark_ll = {
        'RMTPP': [-2.147, -0.908, -0.276, -1.425, -1.683, -2.035, -0.284, -6.168],
        'NHP': [-1.987, -0.764, -0.213, -1.421, -1.542, -1.772, -0.165, -3.144],
        'SAHP': [-2.189, -0.836, -0.346, -1.436, -1.638, -2.136, -0.433, -5.003],
        'THP': [-2.028, -0.785, -0.249, -1.451, -1.566, -1.932, -0.306, -5.287],
        'AttNHP': [-1.933, -0.773, -0.221, -1.395, -1.510, -1.795, -0.201, None],
        'IFTPP': [-1.988, -0.844, -0.282, -1.391, -1.571, -1.769, -0.239, -3.956],
        'DLHP': [-1.871, -0.767, -0.211, -1.415, -1.521, -1.814, -0.145, -2.893]
    }

    time_pce = {
        'RMTPP': [0.1367 , 0.0771 , 0.0351 , 0.0997 , 0.0178 , 0.1154, 0.0423, 0.1307],
        'NHP': [0.0816 , 0.0014 , 0.0127 , 0.0666 , 0.0142 , 0.0460, 0.0587, 0.0742],
        'SAHP': [0.1079 , 0.0623 , 0.0175 , 0.0311 , 0.0138 , 0.1044, 0.0200, 0.2002],
        'THP': [0.1238 , 0.0560 , 0.0333 , 0.1632 , 0.0192 ,  0.1097, 0.0105, 0.1373],
        'AttNHP': [0.0650 , 0.0108 , 0.0053 , 0.0285 , 0.0146 , 0.0136, 0.0484, None],
        'IFTPP': [0.0154 , 0.2335 , 0.0047 , 0.0061 , 0.0068 , 0.0030, 0.0189, 0.1514],
        'DLHP': [0.0581 , 0.0091 , 0.0015 , 0.0191 , 0.0082 ,  0.0073, 0.1035, 0.1216]
    }

    mark_ece = {
        'RMTPP': [0.0642 , 0.0206 , 0.0224 , 0.0152 , 0.0209 , 0.0249 , 0.0287, 0.0822],
        'NHP': [0.0802, 0.0041 , 0.0077 , 0.0528 , 0.0115 , 0.0273 , 0.0192, 0.0421],
        'SAHP': [0.0893 , 0.0866 , 0.0696 , 0.0287 , 0.0148 ,  0.0779 , 0.0558, 0.0480],
        'THP': [0.0214 , 0.0127 , 0.0165 , 0.0651 , 0.0094 ,   0.0424, 0.0242, 0.0995],
        'AttNHP': [0.0235 , 0.0033 , 0.0057 , 0.0256 , 0.0131 , 0.0051, 0.0197, None],
        'IFTPP': [0.0060 , 0.0293 , 0.0087 , 0.0148 , 0.0141  , 0.0059, 0.0166, 0.0192],
        'DLHP': [0.0130 , 0.0045 , 0.0052 , 0.0148 , 0.0174 ,  0.0031, 0.0131, 0.0247]
    }

    def by_dataset(table, table_name):
        """Re-key ``{model: [v0, v1, ...]}`` as ``{model: {dataset: v}}`` via CURR_DATASETS.

        Raises:
            ValueError: if a model's list length differs from len(CURR_DATASETS).
                Without this check, extra values (e.g. a pasted 9-value row from
                the older tables above) would be silently dropped and the rest
                mislabeled, while short lists would fail with a bare IndexError.
        """
        nested = {}
        for model, values in table.items():
            if len(values) != len(CURR_DATASETS):
                raise ValueError(
                    f'{table_name}[{model!r}] has {len(values)} values, expected '
                    f'{len(CURR_DATASETS)} (one per entry of CURR_DATASETS)'
                )
            nested[model] = dict(zip(CURR_DATASETS, values))
        return nested

    time_ll_dict = by_dataset(time_ll, 'time_ll')
    mark_ll_dict = by_dataset(mark_ll, 'mark_ll')
    time_pce_dict = by_dataset(time_pce, 'time_pce')
    mark_ece_dict = by_dataset(mark_ece, 'mark_ece')

    # plot_time_ll_pce_horizontal(time_ll_dict, time_pce_dict, save_fig=True)

    plot_time_ll_pce(time_ll_dict, time_pce_dict, save_fig=True, plot_isobar=False)
    plot_mark_ll_ece(mark_ll_dict, mark_ece_dict, save_fig=True, plot_isobar=False)