import matplotlib
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
# Global matplotlib text settings for every figure produced by this module:
# serif font family, and all text rendered through LaTeX (usetex).
matplotlib.rcParams.update({
    'font.family': 'serif',
    'text.usetex': True,
})

from collections import defaultdict
import numpy as np


# Module-level layout defaults.
# NOTE: none of the plotting functions in this file read these values; each
# function sets its own layout numbers locally. They are kept as a reference.
FULL_PLOT_WIDTH = 5.52  # reference value only. TODO: check against the paper template

# Relative size of the mosaic cells that only carry an axis caption.
WIDTH_CAPTION_GAP = -0.04
HEIGHT_CAPTION_GAP = 0.1

# Relative size of the spacer cells between panels.
WIDTH_GAP = -0.06
HEIGHT_GAP = 0.3

# Figure size, in the units expected by matplotlib's ``figsize``.
SUB_PLOT_WIDTH = 6
SUB_PLOT_HEIGHT = 12

# Base font size for labels.
LABEL_FONT = 9
# Shared colour palette ('tab10'); the plotting functions below pick the
# colour of each curve by index, using cmap(0), cmap(1) and cmap(2).
cmap = plt.get_cmap('tab10')


# Earlier dataset list, kept for reference:
# PREV_DATASETS = ['amazon', 'retweet', 'taxi', 'taobao', 'stackoverflow', 'ehrshot', 'nlb1rep', 'lastfm', 'mimic']

# Dataset keys in the column order of the positional result tables defined in
# the ``__main__`` section: the i-th value of every per-model list belongs to
# the i-th dataset listed here.
CURR_DATASETS = [
    "amazon",
    "retweet_jitter",
    "taxi",
    "taobao",
    "stackoverflow",
    "lastfm",
    "mimic_jitter",
    "ehrshot",
]


# Display name for each dataset key. The insertion order matters: plot_logl
# and plot_logl_grid iterate this dict to decide which panel each dataset
# is drawn in.
DATASET_LABELS = {
    "amazon": "Amazon",
    "retweet_jitter": "Retweet",
    "taxi": "Taxi",
    "taobao": "Taobao",
    "stackoverflow": "StackOverflow",  # previously abbreviated as 'S.O.'
    "lastfm": "Last.fm",
    "mimic_jitter": "MIMIC-II",
    "ehrshot": "EHRSHOT",
    # Entry no longer plotted: 'nlb1rep': 'N.S.'
}

# Models in the left-to-right order used on the x axis of every panel,
# roughly sorted by model performance.
# Earlier orderings, kept for reference:
#   ['RMTPP', 'NHP', 'SAHP', 'THP', 'AttNHP', 'IFTPP', 'DLHP']
#   ['RMTPP', 'THP', 'SAHP', 'NHP', 'IFTPP', 'AttNHP', 'DLHP']  (previous ranking)
MODEL_NAMES = [
    "RMTPP",
    "THP",
    "SAHP",
    "AttNHP",
    "NHP",
    "IFTPP",
    "DLHP",
]  # updated ranking

def my_formatter(x, pos):
    """Format a tick value as a short label.

    Whole numbers are printed without a decimal part; anything else is
    rounded to one decimal with ``np.around``. The ``(x, pos)`` signature is
    the one ``FuncFormatter`` is given in this file; ``pos`` is not used.
    """
    whole = int(x)
    if whole == x:
        return str(whole)
    return str(np.around(x, 1))



def adjust(fig, left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.0):
    """Forward the given margins and spacings to ``fig.subplots_adjust``.

    With the defaults the subplots fill the whole figure with no gap between
    them. All six values are passed on as keyword arguments, unchanged.

    Args:
        fig: object exposing ``subplots_adjust`` (a matplotlib figure).
        left, right, bottom, top: positions of the subplot area edges.
        wspace: width reserved for blank space between subplots.
        hspace: height reserved for blank space between subplots.
    """
    layout = {
        'left': left,
        'right': right,
        'bottom': bottom,
        'top': top,
        'wspace': wspace,
        'hspace': hspace,
    }
    fig.subplots_adjust(**layout)

def disable_axis(ax):
    """Strip ticks and tick labels from ``ax`` and reset its spine colour.

    The plotting functions in this file apply this to the mosaic cells that
    exist only to hold a caption or to act as a spacer.
    """
    # Push the axes far behind the others; avoids a visual rendering bug.
    ax.set_zorder(-100)

    # Clear ticks first, then their labels: x axis, then y axis.
    tick_setters = (
        (ax.set_xticks, ax.set_xticklabels),
        (ax.set_yticks, ax.set_yticklabels),
    )
    for set_ticks, set_labels in tick_setters:
        set_ticks([])
        set_labels([])

    # NOTE(review): presumably this hides the frame around the cell -- confirm
    # how color=None is rendered on the matplotlib version in use.
    plt.setp(ax.spines.values(), color=None)


def plot_logl_main(total_ll, time_ll, mark_ll, anchor_model='RMTPP', save_fig=False,
                   model_subset=None):
    """Draw the 2x2 log-likelihood line figure for a subset of datasets.

    One panel is drawn per entry of ``model_subset``. Each panel shows three
    curves (total, time and mark log-likelihood) with one point per model, in
    the order given by ``MODEL_NAMES``.

    Args:
        total_ll: nested mapping indexed as ``total_ll[model][dataset]``.
            A value of ``None`` means "no result" and leaves a gap in the curve.
        time_ll: same structure, time log-likelihood.
        mark_ll: same structure, mark log-likelihood.
        anchor_model: when truthy, every value is plotted as the difference
            to this model's value on the same dataset; when falsy, the raw
            values are plotted.
        save_fig: when true, the figure is written to
            ``'../plot/logl_lines_main_final.pdf'`` before being shown.
        model_subset: dataset keys to plot (despite the parameter name), at
            most four. Defaults to
            ``['amazon', 'stackoverflow', 'lastfm', 'ehrshot']``.

    Raises:
        ValueError: if ``model_subset`` has more entries than there are panels.
    """
    panel_ids = list('ABCD')
    if model_subset is None:
        # Built per call so the default list is never shared between calls.
        model_subset = ['amazon', 'stackoverflow', 'lastfm', 'ehrshot']
    if len(model_subset) > len(panel_ids):
        raise ValueError(
            f'plot_logl_main has {len(panel_ids)} panels but received '
            f'{len(model_subset)} datasets: {list(model_subset)}'
        )

    # Layout numbers specific to this figure.
    # NOTE(review): the negative ratios are kept exactly as in the tuned
    # layout; presumably they pull neighbouring cells together -- confirm
    # before changing them.
    width_caption_gap, height_caption_gap = -0.04, 0.1
    width_gap, height_gap = -0.06, 0.3
    label_font = 9

    def series(table, dataset):
        """Return one y value per model in MODEL_NAMES; missing results become NaN."""
        values = []
        for model in MODEL_NAMES:
            value = table[model][dataset]
            if value is None:
                # Missing result for this (model, dataset) pair.
                values.append(np.nan)
            elif not anchor_model:
                values.append(value)
            else:
                baseline = table[anchor_model][dataset]
                values.append(np.nan if baseline is None else value - baseline)
        return values

    fig, axd = plt.subplot_mosaic(
        mosaic=[
            ['Y1', 'A', '.', 'B'],
            ['Y1', '.', '.', '.'],
            ['Y2', 'C', '.', 'D'],
            ['Y2', 'X1', '.', 'X2'],
        ],
        gridspec_kw={"width_ratios": [width_caption_gap, 1, width_gap, 1],
                     "height_ratios": [1, height_gap, 1, height_caption_gap]},
        figsize=(10, 2.),
    )

    # Caption-only cells: no ticks, no frame.
    for caption_id in ('X1', 'X2', 'Y1', 'Y2'):
        disable_axis(axd[caption_id])
    # A single y label shared by both rows, shifted up from the 'Y2' cell.
    axd['Y2'].set_ylabel('(Relative) Log-likelihood', fontsize=label_font + 2, labelpad=10, y=1.15)

    # (table, line format, legend label); the position selects the colour.
    curves = (
        (total_ll, '-o', 'Total'),
        (time_ll, '--o', 'Time'),
        (mark_ll, '--o', 'Mark'),
    )
    x_values = list(range(len(MODEL_NAMES)))
    for i, dataset in enumerate(model_subset):
        print(f'Current dataset: {dataset}')
        panel = axd[panel_ids[i]]
        panel.set_xticks(x_values)

        for color_idx, (table, fmt, label) in enumerate(curves):
            panel.plot(x_values, series(table, dataset), fmt,
                       color=cmap(color_idx), lw=1.5, ms=3, label=label)
        panel.set_title(DATASET_LABELS[dataset])
        panel.yaxis.set_major_formatter(FuncFormatter(my_formatter))

        # Only the second half of the panels (the bottom row) shows model names.
        if i >= len(model_subset) // 2:
            panel.set_xticklabels(MODEL_NAMES)
            panel.xaxis.set_tick_params(labelsize=label_font - 1)
        else:
            panel.set_xticklabels([''] * len(MODEL_NAMES))

        # The legend is drawn once, inside the 'amazon' panel.
        if dataset == 'amazon':
            panel.legend(loc=[0.01, 0.18], fontsize=label_font - 2)

    if save_fig:
        fig.savefig('../plot/logl_lines_main_final.pdf', bbox_inches='tight')

    plt.show()



def plot_logl(total_ll, time_ll, mark_ll, anchor_model='RMTPP', save_fig=False):
    """Draw one log-likelihood panel per dataset, stacked in a single column.

    Datasets are taken from ``DATASET_LABELS`` in its insertion order. Each
    panel shows three curves (total, time and mark log-likelihood) with one
    point per model, in the order given by ``MODEL_NAMES``; the dataset name
    is written as the y label of the caption cell next to the panel.

    Args:
        total_ll: nested mapping indexed as ``total_ll[model][dataset]``.
            A value of ``None`` means "no result" and leaves a gap in the curve.
        time_ll: same structure, time log-likelihood.
        mark_ll: same structure, mark log-likelihood.
        anchor_model: when truthy, every value is plotted as the difference
            to this model's value on the same dataset; when falsy, the raw
            values are plotted.
        save_fig: when true, the figure is written to
            ``'../plot/logl_lines_jitter_final.pdf'`` before being shown.

    Raises:
        ValueError: if ``DATASET_LABELS`` has more entries than there are panels.
    """
    panel_ids = list('ABCDEFGH')
    datasets = list(DATASET_LABELS)
    if len(datasets) > len(panel_ids):
        raise ValueError(
            f'plot_logl has {len(panel_ids)} panels but DATASET_LABELS '
            f'lists {len(datasets)} datasets'
        )

    # Layout numbers specific to this figure.
    # NOTE(review): the negative caption ratio is kept exactly as in the
    # tuned layout -- confirm before changing it.
    width_caption_gap, height_caption_gap = -0.01, 0.1
    height_gap = 0.
    label_font = 10
    fig_width, fig_height = 6, 12

    def series(table, dataset):
        """Return one y value per model in MODEL_NAMES; missing results become NaN."""
        values = []
        for model in MODEL_NAMES:
            value = table[model][dataset]
            if value is None:
                # Missing result for this (model, dataset) pair.
                values.append(np.nan)
            elif not anchor_model:
                values.append(value)
            else:
                baseline = table[anchor_model][dataset]
                values.append(np.nan if baseline is None else value - baseline)
        return values

    fig, axd = plt.subplot_mosaic(
        mosaic=[
            ['Y1', 'A'],
            ['Y1', '.'],
            ['Y2', 'B'],
            ['Y2', '.'],
            ['Y3', 'C'],
            ['Y3', '.'],
            ['Y4', 'D'],
            ['Y4', '.'],
            ['Y5', 'E'],
            ['Y5', '.'],
            ['Y6', 'F'],
            ['Y6', '.'],
            ['Y7', 'G'],
            ['Y7', '.'],
            ['Y8', 'H'],
            ['Y8', 'X1'],
        ],
        # Seven (panel, gap) row pairs, then the last panel and the caption row.
        gridspec_kw={"width_ratios": [width_caption_gap, 1],
                     "height_ratios": [1, height_gap] * 7 + [1, height_caption_gap]},
        figsize=(fig_width, fig_height),
    )

    # Caption-only cells: no ticks, no frame.
    for caption_id in [f'Y{k}' for k in range(1, 9)] + ['X1']:
        disable_axis(axd[caption_id])

    # (table, line format, legend label); the position selects the colour.
    curves = (
        (total_ll, '-o', 'Total'),
        (time_ll, '--o', 'Time'),
        (mark_ll, '--o', 'Mark'),
    )
    x_values = list(range(len(MODEL_NAMES)))
    for i, dataset in enumerate(datasets):
        print(f'Current dataset: {dataset}')
        axd[f'Y{i+1}'].set_ylabel(DATASET_LABELS[dataset], fontsize=label_font + 2, y=0.55)
        panel = axd[panel_ids[i]]
        panel.set_xticks(x_values)

        for color_idx, (table, fmt, label) in enumerate(curves):
            panel.plot(x_values, series(table, dataset), fmt,
                       color=cmap(color_idx), lw=1.5, ms=4, label=label)

        # Only the last panel (bottom of the column) shows the model names.
        if i == len(datasets) - 1:
            panel.set_xticklabels(MODEL_NAMES)
        else:
            panel.set_xticklabels([''] * len(MODEL_NAMES))
        panel.yaxis.set_major_formatter(FuncFormatter(my_formatter))

        # The legend is drawn once, above the 'amazon' panel.
        if dataset == 'amazon':
            panel.legend(loc=[0.15, 1.15], ncol=3, fontsize=label_font)

    if save_fig:
        fig.savefig('../plot/logl_lines_jitter_final.pdf', bbox_inches='tight')
    plt.show()



def plot_logl_grid(total_ll, time_ll, mark_ll, anchor_model='RMTPP', save_fig=False):
    """Draw one log-likelihood panel per dataset in a two-column grid.

    Datasets are taken from ``DATASET_LABELS`` in its insertion order and
    fill the grid row by row. Each panel shows three curves (total, time and
    mark log-likelihood) with one point per model, in the order given by
    ``MODEL_NAMES``; the dataset name is written as the y label of the
    caption cell to the left of the panel.

    Args:
        total_ll: nested mapping indexed as ``total_ll[model][dataset]``.
            A value of ``None`` means "no result" and leaves a gap in the curve.
        time_ll: same structure, time log-likelihood.
        mark_ll: same structure, mark log-likelihood.
        anchor_model: when truthy, every value is plotted as the difference
            to this model's value on the same dataset; when falsy, the raw
            values are plotted.
        save_fig: when true, the figure is written to
            ``'../plot/logl_lines_grid_jitter_final.pdf'`` before being shown.

    Raises:
        ValueError: if ``DATASET_LABELS`` has more entries than there are panels.
    """
    panel_ids = list('ABCDEFGH')
    datasets = list(DATASET_LABELS)
    if len(datasets) > len(panel_ids):
        raise ValueError(
            f'plot_logl_grid has {len(panel_ids)} panels but DATASET_LABELS '
            f'lists {len(datasets)} datasets'
        )

    # Layout numbers specific to this figure.
    width_caption_gap, height_caption_gap = 0.02, 0.1
    width_gap, height_gap = 0., 0.
    label_font = 10
    fig_width, fig_height = 10, 8

    def series(table, dataset):
        """Return one y value per model in MODEL_NAMES; missing results become NaN."""
        values = []
        for model in MODEL_NAMES:
            value = table[model][dataset]
            if value is None:
                # Missing result for this (model, dataset) pair.
                values.append(np.nan)
            elif not anchor_model:
                values.append(value)
            else:
                baseline = table[anchor_model][dataset]
                values.append(np.nan if baseline is None else value - baseline)
        return values

    fig, axd = plt.subplot_mosaic(
        mosaic=[
            ['Y1', 'A', '.', 'Y2', 'B'],
            ['Y1', '.', '.', 'Y2', '.'],
            ['Y3', 'C', '.', 'Y4', 'D'],
            ['Y3', '.', '.', 'Y4', '.'],
            ['Y5', 'E', '.', 'Y6', 'F'],
            ['Y5', '.', '.', 'Y6', '.'],
            ['Y7', 'G', '.', 'Y8', 'H'],
            ['Y7', 'X1', '.', 'Y8', 'X2'],
        ],
        # Three (panel, gap) row pairs, then the last panel row and the caption row.
        gridspec_kw={"width_ratios": [width_caption_gap, 1, width_gap, width_caption_gap, 1],
                     "height_ratios": [1, height_gap] * 3 + [1, height_caption_gap]},
        figsize=(fig_width, fig_height),
    )

    # Caption-only cells: no ticks, no frame.
    for caption_id in [f'Y{k}' for k in range(1, 9)] + ['X1', 'X2']:
        disable_axis(axd[caption_id])

    # (table, line format, legend label); the position selects the colour.
    curves = (
        (total_ll, '-o', 'Total Log-Likelihood'),
        (time_ll, '--o', 'Time Log-Likelihood'),
        (mark_ll, '--o', 'Mark Log-Likelihood'),
    )
    x_values = list(range(len(MODEL_NAMES)))
    for i, dataset in enumerate(datasets):
        print(f'Current dataset: {dataset}')
        axd[f'Y{i+1}'].set_ylabel(DATASET_LABELS[dataset], fontsize=label_font + 2, y=0.55)
        panel = axd[panel_ids[i]]
        panel.set_xticks(x_values)

        for color_idx, (table, fmt, label) in enumerate(curves):
            panel.plot(x_values, series(table, dataset), fmt,
                       color=cmap(color_idx), lw=1.5, ms=4, label=label)

        # Only the last two panels (the bottom row) show the model names.
        if i >= len(datasets) - 2:
            panel.set_xticklabels(MODEL_NAMES)
        else:
            panel.set_xticklabels([''] * len(MODEL_NAMES))
        panel.yaxis.set_major_formatter(FuncFormatter(my_formatter))

        # The legend is drawn once, above the 'amazon' panel.
        if dataset == 'amazon':
            panel.legend(loc=[0.28, 1.15], ncol=3, fontsize=label_font)

    if save_fig:
        fig.savefig('../plot/logl_lines_grid_jitter_final.pdf', bbox_inches='tight')
    plt.show()


if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Archived result tables (inactive, kept for reference only).
    # The first archived set has nine values per model, so it does not line
    # up with the eight entries of CURR_DATASETS.
    # ------------------------------------------------------------------
    # time_ll = {
    #     'RMTPP': [0.010, -6.126, 0.622, 2.427, -0.780, -1.888, -1.542, 0.259, -0.108],
    #     'NHP': [2.196, -5.545, 0.728, 2.579, -0.703, -0.758, -1.540, 1.196, 0.269],
    #     'SAHP': [0.173, -5.670, 0.681, 2.612, -0.681, -1.779, -1.541, 0.600, -0.179],
    #     'THP': [-0.070, -5.850, 0.623, 2.242, -0.769, -1.890, -1.540, 0.220, -0.243],
    #     'AttNHP': [2.545, -5.638, 0.724, 2.665, -0.681, None, -1.537, 1.213, 0.069],
    #     'IFTPP': [2.482, -9.437, 0.736, 2.730, -0.660, -2.642, None, 1.290, 1.333],
    #     'DLHP': [2.638, -0.504, 0.738, 2.742, -0.636, 0.723, -1.538, 1.294, 3.937]
    # }
    #
    # mark_ll = {
    #     'RMTPP': [-2.148, -0.841, -0.275, -1.421, -1.623, -6.147, -4.299, -2.035, -0.253],
    #     'NHP': [-1.992, -0.764, -0.212, -1.416, -1.540, -3.149, -4.247, -1.774, -0.159],
    #     'SAHP': [-2.213, -0.775, -0.308, -1.411, -1.602, -5.066, -4.201, -2.100, -0.432],
    #     'THP': [-2.028, -0.785, -0.249, -1.451, -1.563, -5.294, -4.208, -1.936, -0.311],
    #     'AttNHP': [-1.938, -0.771, -0.225, -1.387, -1.498, None, -4.172, -1.771, -0.162],
    #     'IFTPP': [-1.989, -0.843, -0.282, -1.395, -1.565, -3.782, None, -1.763, -0.240],
    #     'DLHP': [-1.873, -0.807, -0.209, -1.410, -1.529, -2.912, -4.169, -1.790, -0.124]
    # }
    #
    #
    # total_ll = {
    #     'RMTPP': [-2.137, -6.967, 0.347, 1.006, -2.403, -8.035, -5.841, -1.776, -0.361],
    #     'NHP': [0.205, -6.310, 0.516, 1.163, -2.243, -3.907, -5.787, -0.578, 0.110],
    #     'SAHP': [-2.040, -6.445, 0.372, 1.201, -2.283, -6.845, -5.742, -1.500, -0.611],
    #     'THP': [-2.098, -6.635, 0.374, 0.791, -2.331, -7.183, -5.748, -1.716, -0.554],
    #     'AttNHP': [0.608, -6.409, 0.499, 1.278, -2.179, None, -5.709, -0.558, -0.093],
    #     'IFTPP': [0.493 , -10.280 , 0.454 , 1.335,  -2.224, -6.424, None, -0.472, 1.093],
    #     'DLHP': [0.765, -1.311, 0.528, 1.332, -2.165, -2.189, -5.707, -0.496, 3.814]
    # }

    # Second archived set (eight values per model).
    # time_ll = {
    #     'RMTPP': [0.010 , -6.231 , 0.622 , 2.427 , -0.780   , 0.259 , -0.182 , -1.888],
    #     'NHP': [2.196 , -5.583 , 0.728 , 2.579 , -0.703   , 1.196 , 0.240 , -0.758],
    #     'SAHP': [0.173 , -5.895 , 0.681 , 2.612 , -0.681  , 0.600 , -0.298 , -1.779],
    #     'THP': [-0.070 , -5.867 , 0.623 , 2.242 , -0.769  , 0.220 , -0.277 , -1.890],
    #     'AttNHP': [2.545 , -5.688 , 0.724 , 2.665 , -0.681 , 1.213 ,  -0.017 , None],
    #     'IFTPP': [2.482 , -9.494 , 0.736 , 2.730 , -0.660 , 1.290 , 0.536 , -2.642],
    #     'DLHP': [2.638 , -5.600 , 0.738 , 2.742 , -0.636  , 1.294 , 1.345 , 0.723]
    # }
    #
    # mark_ll = {
    #     'RMTPP': [-2.148 , -0.939 , -0.275 , -1.421 , -1.623   , -2.035 , -0.298 , -6.147],
    #     'NHP': [-1.992 , -0.764 , -0.212 , -1.416 , -1.540   , -1.774 ,  -0.164 , -3.149],
    #     'SAHP': [-2.213 , -0.809 , -0.308 , -1.411 , -1.602   , -2.100 , -0.475 , -5.066],
    #     'THP': [-2.028 , -0.786 , -0.249 , -1.451 , -1.563   , -1.936 , -0.310 , -5.294],
    #     'AttNHP': [-1.938 , -0.771 , -0.225 , -1.387 , -1.498   , -1.771 ,-0.227 , None],
    #     'IFTPP': [-1.989 , -0.845 , -0.282 , -1.395 ,  -1.565   ,-1.763 , -0.237 , -3.782],
    #     'DLHP': [-1.873 , -0.767 , -0.209 , -1.410 , -1.529  , -1.790 , -0.114, -2.912]
    # }
    #
    # total_ll = {
    #     'RMTPP': [-2.137 , -7.169 , 0.347 , 1.006 , -2.403   , -1.776 , -0.480 , -8.035],
    #     'NHP': [0.205 , -6.346 , 0.516 , 1.163, -2.243  , -0.578 , 0.076 , -3.907],
    #     'SAHP': [-2.040 , -6.704 , 0.372 , 1.201 , -2.283  , -1.500 , -0.773 ,-6.845],
    #     'THP': [-2.098 , -6.652 , 0.374 , 0.791 , -2.331  , -1.716, -0.587 ,-7.183],
    #     'AttNHP': [0.608 , -6.459 , 0.499 , 1.278 , -2.179  , -0.558 ,  -0.244 , None],
    #     'IFTPP': [0.493 , -10.339 , 0.454 , 1.335 , -2.224 , -0.472 , 0.299 ,-6.424],
    #     'DLHP': [0.765 , -6.367 , 0.528 , 1.332 , -2.165  , -0.496 , 1.231 ,-2.189]
    # }

    # ------------------------------------------------------------------
    # Active result tables.
    # Each list holds one value per dataset, in CURR_DATASETS order.
    # None marks a missing result (AttNHP on the last dataset, 'ehrshot').
    # ------------------------------------------------------------------
    time_ll = {
        'RMTPP': [0.011, -6.191, 0.622, 2.428, -0.797, 0.256, -0.188, -1.913],
        'NHP': [2.116, -5.584, 0.727, 2.578, -0.699, 1.198, 0.225, -0.821],
        'SAHP': [0.115, -5.872, 0.645, 2.604, -0.703, 0.489, -0.244, -1.801],
        'THP': [-0.068, -5.874, 0.621, 2.242, -0.772, 0.220, -0.271, -1.921],
        'AttNHP': [2.416, -5.726, 0.714, 2.654, -0.684, 1.203, 0.031, None],
        'IFTPP': [2.483, -9.500, 0.735, 2.708, -0.662, 1.277, 0.555, -2.640],
        'DLHP': [2.652, -5.598, 0.733, 2.719, -0.641, 1.257, 1.389, 0.382],
    }

    mark_ll = {
        'RMTPP': [-2.147, -0.908, -0.276, -1.425, -1.683, -2.035, -0.284, -6.168],
        'NHP': [-1.987, -0.764, -0.213, -1.421, -1.542, -1.772, -0.165, -3.144],
        'SAHP': [-2.189, -0.836, -0.346, -1.436, -1.638, -2.136, -0.433, -5.003],
        'THP': [-2.028, -0.785, -0.249, -1.451, -1.566, -1.932, -0.306, -5.287],
        'AttNHP': [-1.933, -0.773, -0.221, -1.395, -1.510, -1.795, -0.201, None],
        'IFTPP': [-1.988, -0.844, -0.282, -1.391, -1.571, -1.769, -0.239, -3.956],
        'DLHP': [-1.871, -0.767, -0.211, -1.415, -1.521, -1.814, -0.145, -2.893],
    }

    total_ll = {
        'RMTPP': [-2.136, -7.098, 0.346, 1.003, -2.480, -1.780, -0.472, -8.081],
        'NHP': [0.129, -6.348, 0.514, 1.157, -2.241, -0.574, 0.060, -3.966],
        'SAHP': [-2.074, -6.708, 0.298, 1.168, -2.341, -1.646, -0.677, -6.804],
        'THP': [-2.096, -6.659, 0.372, 0.790, -2.338, -1.712, -0.577, -7.208],
        'AttNHP': [0.484, -6.499, 0.493, 1.259, -2.194, -0.592, -0.170, None],
        'IFTPP': [0.496, -10.344, 0.453, 1.318, -2.233, -0.492, 0.317, -6.596],
        'DLHP': [0.781, -6.365, 0.522, 1.304, -2.163, -0.557, 1.243, -2.512],
    }

    # Re-key every positional table as table[model][dataset], the layout the
    # plotting functions index into.
    time_ll_dict = defaultdict(dict)
    mark_ll_dict = defaultdict(dict)
    total_ll_dict = defaultdict(dict)
    table_pairs = (
        (time_ll, time_ll_dict),
        (mark_ll, mark_ll_dict),
        (total_ll, total_ll_dict),
    )
    for positional, keyed in table_pairs:
        for model_name, row in positional.items():
            for column, dataset_name in enumerate(CURR_DATASETS):
                keyed[model_name][dataset_name] = row[column]

    # The single-column figure is currently switched off:
    # plot_logl(total_ll_dict, time_ll_dict, mark_ll_dict, save_fig=True)
    plot_logl_main(total_ll_dict, time_ll_dict, mark_ll_dict, save_fig=True)
    plot_logl_grid(total_ll_dict, time_ll_dict, mark_ll_dict, save_fig=True)