# TODO: check whether some points are not visible in the plots
import pickle
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
# Typeset every piece of figure text with LaTeX in a serif font
# (this requires a working LaTeX installation).
matplotlib.rcParams.update({
    'font.family': 'serif',
    'text.usetex': True,
})

# Number of nominal CDF levels used for the PCE curves.
PCE_BIN = 50
# Bin count presumably meant for an ECE plot; not used by the code below.
ECE_BIN = 20

# Reference width of a full-width figure, kept for documentation only and not
# used below. TODO: check against the paper template.
FULL_PLOT_WIDTH = 5.52
# Relative width/height (gridspec ratios) of the grid cells holding the shared
# axis captions.
WIDTH_CAPTION_GAP = 0.1
HEIGHT_CAPTION_GAP = 0.1

# Relative width/height (gridspec ratios) of the spacer cells between subplots.
WIDTH_GAP = 0.075
HEIGHT_GAP = 0.075
# Figure size passed as ``figsize`` (inches).
SUB_PLOT_WIDTH = 8
SUB_PLOT_HEIGHT = 4
# Base font size; titles, captions and legend sizes are offsets from it.
LABEL_FONT = 9
# PLOT_AVG_ACC=True
# Qualitative palette: the i-th model is drawn with colour cmap(i).
cmap = plt.get_cmap('tab10')


# Checkpoint directory name -> title shown above the dataset's subplot.
# Insertion order decides which subplot each dataset occupies. The commented
# pairs are alternative dataset variants that are currently not plotted.
DATASET_LABELS = dict([
    ('amazon', 'Amazon'),
    # ('retweet', 'Retweet'),
    ('retweet_jitter', 'Retweet'),
    ('taxi', 'Taxi'),
    ('taobao', 'Taobao'),
    ('stackoverflow', 'StackOverflow'),
    ('lastfm', 'Last.fm'),
    # ('mimic', 'MIMIC-II'),
    ('mimic_jitter', 'MIMIC-II'),
    ('ehrshot', 'EHRSHOT'),
    # ('nlb1rep', 'N.S.'),
])
# Model checkpoint directory names; list position also selects the plot colour.
MODEL_NAMES = 'RMTPP NHP SAHP THP AttNHP IntensityFree DLHP'.split()


def adjust(fig, left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.0):
    """Forward subplot margins and inter-subplot spacing to ``fig.subplots_adjust``.

    ``left``/``right``/``bottom``/``top`` place the edges of the subplot area;
    ``wspace``/``hspace`` reserve horizontal/vertical blank space between
    subplots. With the defaults the subplots fill the figure with no gaps.
    """
    edges = dict(left=left, right=right, bottom=bottom, top=top)
    gaps = dict(wspace=wspace, hspace=hspace)
    fig.subplots_adjust(**edges, **gaps)

def disable_axis(ax):
    """Turn ``ax`` into an invisible placeholder that only carries axis labels.

    Ticks, tick labels and all spines are removed, but the axes object itself
    is kept so ``set_xlabel``/``set_ylabel`` can still be called on it (this
    file uses such placeholders for captions shared by several subplots).
    """
    ax.set_zorder(-100)  # Avoids a visual rendering bug
    ax.set_xticks([])
    ax.set_xticklabels([])
    ax.set_yticks([])
    ax.set_yticklabels([])
    # Hide the frame via visibility rather than ``color=None``: a None edge
    # colour only turns invisible under default rcParams and falls back to the
    # rc ``patch.edgecolor`` (black) when ``patch.force_edgecolor`` is set.
    for spine in ax.spines.values():
        spine.set_visible(False)

# def plot_pce_horizontal(save_fig=False):
#     fig, axd = plt.subplot_mosaic(
#         mosaic=[
#             ['Y1', 'A', '.', 'B', '.', 'C', '.', 'D', '.', 'E', '.', 'F', '.', 'G', '.', 'H'],
#             ['Y1', 'X1', '.', 'X2', '.', 'X3', '.', 'X4', 'X9', 'X5', '.', 'X6', '.', 'X7', '.', 'X8'],
#         ],
#         gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1,
#                                       WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1],
#                      "height_ratios": [1, HEIGHT_CAPTION_GAP]},
#         figsize=(SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT),
#     )
#
#     for x_idx in ['X1', 'X2', 'X3', 'X4', 'X5', 'X6', 'X7', 'X8', 'X9']:
#         disable_axis(axd[x_idx])
#         if x_idx == "X9":
#             axd[x_idx].set_xlabel("Predicted CDF", fontsize=LABEL_FONT+2, labelpad=10)
#
#
#     disable_axis(axd['Y1']); axd['Y1'].set_ylabel("Empirical CDF", fontsize=LABEL_FONT+2, labelpad=10)
#     # disable_axis(axd['Y2']); axd['Y2'].set_ylabel("Empirical CDF of dts", fontsize=LABEL_FONT + 2, labelpad=10)
#
#     pm = torch.linspace(0, 1, PCE_BIN + 1)[1:][None, ...]
#     probs = pm.squeeze()
#
#     subplot_ids = list('ABCDEFHG')
#     for dataset_id, dataset in enumerate(DATASET_LABELS.keys()):
#         print(f'Current dataset: {dataset}')
#         subplot_axd = axd[subplot_ids[dataset_id]]
#
#         subplot_axd.set_xticks([0.0, 0.5, 1.0])
#         subplot_axd.set_yticks([0.0, 0.5, 1.0])
#         subplot_axd.set_xticks([0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9], minor=True)
#         subplot_axd.set_yticks([0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9], minor=True)
#
#         if subplot_ids[dataset_id] in ['A', 'E']:
#             subplot_axd.set_yticklabels(["0", "0.5", "1"])
#         else:
#             subplot_axd.set_yticklabels(["", "", ""])
#         if subplot_ids[dataset_id] in ['E', 'F', 'G', 'H']:
#             subplot_axd.set_xticklabels(["0", "0.5", "1"])
#         else:
#             subplot_axd.set_xticklabels(["", "", ""])
#
#         subplot_axd.set_xlim((0, 1))
#         subplot_axd.set_ylim((0, 1))
#
#         for model_id, model_name in enumerate(MODEL_NAMES):
#             print(f'Current model: {model_name}')
#
#             if dataset == 'ehrshot' and model_name == 'AttNHP':
#                 continue
#             with open(f'../checkpoints/{dataset}/{model_name}/dt_cdf.pkl', 'rb') as f:
#                 dt_cdf = pickle.load(f)
#
#             num_events = len(dt_cdf)
#             indicator_eval = (torch.tensor(dt_cdf)[..., None] <= pm).int().sum(dim=0) / num_events
#             subplot_axd.plot(indicator_eval, probs, label=model_name, color=cmap(model_id), alpha=0.7)
#             subplot_axd.plot(probs, probs, linestyle='--', color='tab:gray')
#         subplot_axd.set_title(DATASET_LABELS[dataset])
#         if dataset =='amazon':
#             subplot_axd.legend(loc=[0.1, 1.28 ], ncol=7, fontsize=LABEL_FONT - 2)
#     if save_fig:
#         plt.savefig('../plot/pce_all_horizontal.pdf', bbox_inches='tight')
#     plt.show()

def _calibration_curve(pit_values, levels):
    """Return the empirical CDF of ``pit_values`` at ``levels`` and its PCE.

    ``pit_values`` is assumed to be a flat sequence (or 1-D tensor/array) of
    predicted CDF values, one per event — TODO confirm against the code that
    writes ``dt_cdf.pkl``. ``levels`` is a 1-D tensor of nominal CDF levels.

    Returns:
        ``(empirical, pce)``: ``empirical[j]`` is the fraction of values that
        are <= ``levels[j]``; ``pce`` is the mean absolute gap between
        ``empirical`` and ``levels`` as a Python float.

    Raises:
        ValueError: if ``pit_values`` is empty (the fractions are undefined).
    """
    values = torch.as_tensor(pit_values)
    if values.numel() == 0:
        raise ValueError('cannot compute a calibration curve from zero events')
    # (N, 1) <= (K,) broadcasts to (N, K); summing over events counts the
    # values at or below each level.
    hits = (values[..., None] <= levels).int().sum(dim=0)
    empirical = hits / len(values)
    pce = (empirical - levels).abs().mean().item()
    return empirical, pce


def _format_calibration_axis(ax, show_x_labels, show_y_labels):
    """Give a calibration subplot unit limits and 0/0.5/1 ticks (0.1 minor steps).

    Tick labels are only written on the requested axes so that inner subplots
    of the grid stay uncluttered.
    """
    major = [0.0, 0.5, 1.0]
    minor = [0.1, 0.2, 0.3, 0.4, 0.6, 0.7, 0.8, 0.9]
    ax.set_xticks(major)
    ax.set_yticks(major)
    ax.set_xticks(minor, minor=True)
    ax.set_yticks(minor, minor=True)
    ax.set_xticklabels(["0", "0.5", "1"] if show_x_labels else ["", "", ""])
    ax.set_yticklabels(["0", "0.5", "1"] if show_y_labels else ["", "", ""])
    ax.set_xlim((0, 1))
    ax.set_ylim((0, 1))


def plot_pce_all(save_fig=False, checkpoint_dir='../checkpoints',
                 save_path='../plot/pce_all_jitter_rt_final.pdf'):
    """Plot PCE calibration curves for every dataset/model pair in a 2x4 grid.

    Datasets are taken from DATASET_LABELS in insertion order and fill the
    subplots A-H row by row. For each model in MODEL_NAMES, the values stored
    in ``<checkpoint_dir>/<dataset>/<model>/dt_cdf.pkl`` (one per event) are
    turned into an empirical CDF over PCE_BIN nominal levels. That curve is
    drawn with the nominal level on x ("Predicted CDF") and the empirical
    fraction on y ("Empirical CDF"), and the PCE is printed.

    Args:
        save_fig: if True, save the figure to ``save_path`` before showing it.
        checkpoint_dir: root directory holding the ``dt_cdf.pkl`` files. They
            are read with ``pickle.load``, so only point this at trusted files.
        save_path: output file written when ``save_fig`` is True.

    Raises:
        ValueError: if DATASET_LABELS has more entries than there are subplots.
        FileNotFoundError: if a required ``dt_cdf.pkl`` does not exist.
    """
    fig, axd = plt.subplot_mosaic(
        mosaic=[
            ['Y1', 'A', '.', 'B', '.', 'C', '.', 'D'],
            ['Y1', '.', '.', '.', '.', '.', '.', '.'],
            ['Y1', 'E', '.', 'F', '.', 'G', '.', 'H'],
            ['Y1', 'X1', '.', 'X2', 'X5', 'X3', '.', 'X4'],
        ],
        gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1],
                     "height_ratios": [1, HEIGHT_GAP, 1, HEIGHT_CAPTION_GAP]},
        figsize=(SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT),
    )

    # The X*/Y1 cells are invisible placeholders; X5 sits in the middle of the
    # bottom row and Y1 spans the left column, so they carry the shared captions.
    for x_idx in ['X1', 'X2', 'X3', 'X4', 'X5']:
        disable_axis(axd[x_idx])
    axd['X5'].set_xlabel("Predicted CDF", fontsize=LABEL_FONT + 2, labelpad=10)
    disable_axis(axd['Y1'])
    axd['Y1'].set_ylabel("Empirical CDF", fontsize=LABEL_FONT + 2, labelpad=10)

    # Nominal CDF levels 1/PCE_BIN, 2/PCE_BIN, ..., 1 (the 0 level is dropped).
    levels = torch.linspace(0, 1, PCE_BIN + 1)[1:]

    subplot_ids = 'ABCDEFGH'
    if len(DATASET_LABELS) > len(subplot_ids):
        raise ValueError(
            f'{len(DATASET_LABELS)} datasets but only {len(subplot_ids)} subplots')

    for subplot_id, (dataset, dataset_label) in zip(subplot_ids, DATASET_LABELS.items()):
        print(f'Current dataset: {dataset}')
        ax = axd[subplot_id]
        _format_calibration_axis(
            ax,
            show_x_labels=subplot_id in {'E', 'F', 'G', 'H'},  # bottom row
            show_y_labels=subplot_id in {'A', 'E'},            # left column
        )

        for model_id, model_name in enumerate(MODEL_NAMES):
            print(f'Current model: {model_name}')

            # Deliberately excluded (presumably no such checkpoint — TODO confirm).
            if dataset == 'ehrshot' and model_name == 'AttNHP':
                continue
            path = f'{checkpoint_dir}/{dataset}/{model_name}/dt_cdf.pkl'
            with open(path, 'rb') as f:
                dt_cdf = pickle.load(f)
            if len(dt_cdf) == 0:
                print(f'Skipping {model_name}: {path} holds no events')
                continue

            empirical, pce = _calibration_curve(dt_cdf, levels)
            print(f'PCE: {pce:.5f}')

            model_name_label = 'IFTPP' if model_name == 'IntensityFree' else model_name
            # Nominal level on x and empirical fraction on y, matching the
            # "Predicted CDF" / "Empirical CDF" captions.
            ax.plot(levels, empirical, label=model_name_label, color=cmap(model_id), alpha=0.7)

        # Perfect-calibration reference line, drawn once per subplot.
        ax.plot(levels, levels, linestyle='--', color='tab:gray')
        ax.set_title(dataset_label, fontsize=LABEL_FONT + 1)
        if dataset == 'amazon':
            # One legend for the whole figure, anchored above the first subplot.
            ax.legend(loc=[-0.03, 1.3], ncol=7, fontsize=LABEL_FONT - 1)

    if save_fig:
        fig.savefig(save_path, bbox_inches='tight')
    plt.show()

# Only build and save the figure when run as a script, not on import.
if __name__ == '__main__':
    plot_pce_all(save_fig=True)
    # plot_pce_horizontal(save_fig=True)
