# TODO: set the font and adjust the x/y ticks
# TODO: update the ECE values in the main paper
# TODO: could put EHRShot in the corner so there is no empty plot in the middle


import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns


# Number of probability levels for the PCE curves / confidence bins for the ECE bars.
PCE_BIN = 50
ECE_BIN = 20

# Kept as a reference only; nothing below reads it. TODO: check paper template
FULL_PLOT_WIDTH = 6.75

# Relative size (gridspec ratio) of the mosaic cells that only carry the shared
# axis captions: the leftmost column and the bottom row respectively.
WIDTH_CAPTION_GAP = 0.1
HEIGHT_CAPTION_GAP = 0.1

# Relative size (gridspec ratio) of the empty spacer columns / rows between panels.
WIDTH_GAP = 0.1
HEIGHT_GAP = 0.1

# Passed to ``figsize``; despite the names these size the whole figure.
SUB_PLOT_WIDTH = 36
SUB_PLOT_HEIGHT = 36

# Font size of the shared axis captions.
LABEL_FONT = 12

# Color cycle; indexed by a model's position in MODEL_NAMES.
cmap = plt.get_cmap('tab10')


# Dataset directory name -> panel title. The insertion order also decides which
# panel each dataset is drawn in, so do not reorder casually.
DATASET_LABELS = {
    'amazon': 'Amazon',
    'retweet': 'Retweet',
    'taxi': 'Taxi',
    'taobao': 'Taobao',
    'stackoverflow': 'StackOverflow',
    'ehrshot': 'EHRShot',
    'nlb1rep': 'Neural Spike',
    'lastfm': 'Last.fm',
}

# Model directory names; the position in this list selects the color from ``cmap``.
MODEL_NAMES = [
    'RMTPP',
    'NHP',
    'SAHP',
    'THP',
    'AttNHP',
    'IntensityFree',
    'DLHP',
]


def adjust(fig, left=0.0, right=1.0, bottom=0.0, top=1.0, wspace=0.0, hspace=0.0):
    """Forward margin and spacing settings to ``fig.subplots_adjust``.

    Args:
        fig: object exposing ``subplots_adjust`` (a matplotlib figure).
        left, right, bottom, top: positions of the subplot area's edges.
        wspace, hspace: blank width / height reserved between subplots.
    """
    layout = {
        'left': left,
        'right': right,
        'bottom': bottom,
        'top': top,
        'wspace': wspace,
        'hspace': hspace,
    }
    fig.subplots_adjust(**layout)

def disable_axis(ax):
    """Clear the ticks, tick labels and spine color of ``ax``.

    Used on the mosaic cells that exist only to carry a shared axis caption.
    """
    ax.set_zorder(-100)  # Avoids a visual rendering bug
    tick_setters = (
        (ax.set_xticks, ax.set_xticklabels),
        (ax.set_yticks, ax.set_yticklabels),
    )
    for set_ticks, set_ticklabels in tick_setters:
        set_ticks([])
        set_ticklabels([])
    # Resets the color of every spine of this axes to None.
    plt.setp(ax.spines.values(), color=None)





def plot_pce_all(save_fig=False):
    """Plot the dt-CDF calibration curves of every model, one panel per dataset.

    For each dataset in ``DATASET_LABELS`` and each model in ``MODEL_NAMES`` the
    pickled values in ``../checkpoints/<dataset>/<model>/dt_cdf.pkl`` are loaded.
    For each of the ``PCE_BIN`` levels ``p`` in (0, 1] the fraction of loaded
    values that are ``<= p`` is computed, and the resulting curve is drawn in
    that dataset's panel together with a dashed gray diagonal.

    Args:
        save_fig: when true, write the figure to ``../plot/pce_temp_all.pdf``
            before showing it.

    Raises:
        FileNotFoundError: propagated from ``open`` when a ``dt_cdf.pkl`` file
            is missing.
    """
    fig, axd = plt.subplot_mosaic(
        mosaic=[
            ['Y1', 'A0', '.', 'B0', '.', 'C0', '.', 'D0', '.', 'E0', '.', 'F0', '.', 'G0', '.', 'H0'],
            ['Y1', 'X1', '.', 'X2', '.', 'X3', '.', 'X4', '.', 'X5', '.', 'X6', '.', 'X7', '.', 'X8'],
        ],
        gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1,
                                      WIDTH_GAP, 1,  WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1],
                     "height_ratios": [1, HEIGHT_CAPTION_GAP]},
        figsize=(SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT / 8),
    )

    # The X*/Y* cells hold no data; they only carry the shared axis captions.
    for x_idx in ['X1', 'X2', 'X3', 'X4', 'X5', 'X6', 'X7', 'X8']:
        disable_axis(axd[x_idx])
        axd[x_idx].set_xlabel("Predicted CDF of dts", fontsize=LABEL_FONT)

    disable_axis(axd['Y1'])
    axd['Y1'].set_ylabel("Empirical CDF of dts", fontsize=LABEL_FONT, x=1, y=0.6)

    # Levels p = 1/PCE_BIN, ..., 1, shaped (1, PCE_BIN) so they broadcast
    # against a column of loaded values.
    pm = torch.linspace(0, 1, PCE_BIN + 1)[1:][None, ...]
    probs = pm.squeeze()

    # NOTE(review): 'H' precedes 'G', so the 7th dataset (nlb1rep) lands in
    # panel H0 and the 8th (lastfm) in G0 - confirm this ordering is intended.
    subplot_ids = list('ABCDEFHG')
    for subplot_id, (dataset, dataset_title) in zip(subplot_ids, DATASET_LABELS.items()):
        print(f'Current dataset: {dataset}')
        subplot_axd = axd[subplot_id + '0']

        for model_id, model_name in enumerate(MODEL_NAMES):
            print(f'Current model: {model_name}')

            # This dataset/model combination is skipped (no curve is drawn).
            if dataset == 'ehrshot' and model_name == 'AttNHP':
                continue
            # pickle.load can execute arbitrary code; only read checkpoints
            # that were produced by a trusted run.
            with open(f'../checkpoints/{dataset}/{model_name}/dt_cdf.pkl', 'rb') as f:
                dt_cdf = pickle.load(f)

            # assumes dt_cdf is a flat sequence of per-event values in [0, 1]
            # - TODO confirm against the code that writes dt_cdf.pkl
            num_events = len(dt_cdf)
            # Fraction of loaded values that are <= each level p.
            indicator_eval = (torch.tensor(dt_cdf)[..., None] <= pm).int().sum(dim=0) / num_events
            # NOTE(review): x is the observed fraction and y is the level p,
            # while the captions read "Predicted" on x and "Empirical" on y
            # - confirm the captions match the plotted axes.
            subplot_axd.plot(indicator_eval, probs, label=model_name, color=cmap(model_id))

        # Reference diagonal: drawn once per panel, after the model curves,
        # instead of once per model.
        subplot_axd.plot(probs, probs, linestyle='--', color='tab:gray')
        subplot_axd.set_title(dataset_title)
        # A single legend is enough since every panel uses the same colors.
        if dataset == 'nlb1rep':
            subplot_axd.legend()

    if save_fig:
        fig.savefig('../plot/pce_temp_all.pdf', bbox_inches='tight')
    plt.show()



def _binned_accuracy(mark_conf, mark_pred, true_mark, bin_edges):
    """Return (bin centers, accuracy per bin) for predictions grouped by confidence."""
    conf = np.asarray(mark_conf)
    pred = np.asarray(mark_pred)
    true = np.asarray(true_mark)

    last_bin = len(bin_edges) - 2
    centers = []
    accuracy = []
    for i, (lo, hi) in enumerate(zip(bin_edges[:-1], bin_edges[1:])):
        # Bins are [lo, hi); the last one is closed on the right so that a
        # confidence equal to the top edge is not silently dropped.
        below_hi = (conf <= hi) if i == last_bin else (conf < hi)
        in_bin = (lo <= conf) & below_hi

        centers.append((lo + hi) / 2)
        # An empty bin gets height 0, i.e. no visible bar.
        accuracy.append(np.mean(true[in_bin] == pred[in_bin]) if in_bin.any() else 0)
    return centers, accuracy


def plot_ece_all(save_fig=False):
    """Plot a reliability diagram for every (model, dataset) pair.

    The figure is a grid with one row per model in ``MODEL_NAMES`` and one
    column per dataset in ``DATASET_LABELS``.  For each pair the pickled
    ``mark_conf``, ``mark_pred`` and ``true_mark`` sequences are loaded from
    ``../checkpoints/<dataset>/<model>/``.  Each panel shows the accuracy
    within each of ``ECE_BIN`` equal-width confidence bins as bars, a dashed
    gray diagonal, and a KDE of the confidences on a twin y-axis.

    Args:
        save_fig: when true, write the figure to ``../plot/ece_temp.pdf``
            before showing it.

    Raises:
        FileNotFoundError: propagated from ``open`` when one of the pickle
            files is missing.
    """
    fig, axd = plt.subplot_mosaic(
        mosaic=[
            ['Y1', 'A0', '.', 'B0', '.', 'C0', '.', 'D0', '.', 'E0', '.', 'F0', '.', 'G0', '.', 'H0'],
            ['Y1', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'],
            ['Y2', 'A1', '.', 'B1', '.', 'C1', '.', 'D1', '.', 'E1', '.', 'F1', '.', 'G1', '.', 'H1'],
            ['Y2', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'],
            ['Y3', 'A2', '.', 'B2', '.', 'C2', '.', 'D2', '.', 'E2', '.', 'F2', '.', 'G2', '.', 'H2'],
            ['Y3', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'],
            ['Y4', 'A3', '.', 'B3', '.', 'C3', '.', 'D3', '.', 'E3', '.', 'F3', '.', 'G3', '.', 'H3'],
            ['Y4', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'],
            ['Y5', 'A4', '.', 'B4', '.', 'C4', '.', 'D4', '.', 'E4', '.', 'F4', '.', 'G4', '.', 'H4'],
            ['Y5', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'],
            ['Y6', 'A5', '.', 'B5', '.', 'C5', '.', 'D5', '.', 'E5', '.', 'F5', '.', 'G5', '.', 'H5'],
            ['Y6', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.', '.'],
            ['Y7', 'A6', '.', 'B6', '.', 'C6', '.', 'D6', '.', 'E6', '.', 'F6', '.', 'G6', '.', 'H6'],
            ['Y7', 'X1', '.', 'X2', '.', 'X3', '.', 'X4', '.', 'X5', '.', 'X6', '.', 'X7', '.', 'X8'],
        ],
        gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1,
                                      WIDTH_GAP, 1,  WIDTH_GAP, 1, WIDTH_GAP, 1, WIDTH_GAP, 1],
                     "height_ratios": [1, HEIGHT_GAP, 1, HEIGHT_GAP, 1, HEIGHT_GAP, 1, HEIGHT_GAP,
                                       1, HEIGHT_GAP, 1, HEIGHT_GAP, 1, HEIGHT_CAPTION_GAP]},
        figsize=(SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT),
    )

    # The X*/Y* cells hold no data; they only carry the shared axis captions.
    for x_idx in ['X1', 'X2', 'X3', 'X4', 'X5', 'X6', 'X7', 'X8']:
        disable_axis(axd[x_idx])
        axd[x_idx].set_xlabel("Confidence", fontsize=LABEL_FONT)

    for y_idx in ['Y1', 'Y2', 'Y3', 'Y4', 'Y5', 'Y6', 'Y7']:
        disable_axis(axd[y_idx])
        axd[y_idx].set_ylabel("Accuracy", fontsize=LABEL_FONT, x=1, y=0.6)

    prob_bins = np.linspace(0, 1, ECE_BIN + 1)
    # NOTE(review): 'H' precedes 'G', so the 7th dataset (nlb1rep) lands in
    # column H and the 8th (lastfm) in column G - confirm this is intended.
    subplot_ids = list('ABCDEFHG')
    for model_id, model_name in enumerate(MODEL_NAMES):
        print(f'Current model: {model_name}')
        for subplot_id, (dataset, dataset_title) in zip(subplot_ids, DATASET_LABELS.items()):
            print(f'Current dataset: {dataset}')

            subplot_axd = axd[subplot_id + str(model_id)]
            # This dataset/model combination is skipped; its panel stays empty.
            if dataset == 'ehrshot' and model_name == 'AttNHP':
                continue

            # pickle.load can execute arbitrary code; only read checkpoints
            # that were produced by a trusted run.
            with open(f'../checkpoints/{dataset}/{model_name}/mark_conf.pkl', 'rb') as f:
                mark_conf = pickle.load(f)
            with open(f'../checkpoints/{dataset}/{model_name}/mark_pred.pkl', 'rb') as f:
                mark_pred = pickle.load(f)
            with open(f'../checkpoints/{dataset}/{model_name}/true_mark.pkl', 'rb') as f:
                true_mark = pickle.load(f)

            # assumes the three pickles are equally long flat sequences with
            # one entry per event - TODO confirm against the code that writes them
            confidence, accuracy = _binned_accuracy(mark_conf, mark_pred, true_mark, prob_bins)

            subplot_axd.set_ylim([0, 1])
            subplot_axd.bar(x=confidence, height=accuracy, label=model_name, color=cmap(model_id), alpha=0.3,
                            width=1 / ECE_BIN)

            # Density of the confidences, drawn on its own (right-hand) y scale.
            ax_right = subplot_axd.twinx()
            sns.kdeplot(mark_conf, ax=ax_right, cut=0, color=cmap(model_id))
            subplot_axd.plot(prob_bins, prob_bins, linestyle='--', color='tab:gray')
            subplot_axd.legend()
            subplot_axd.set_title(dataset_title)

    if save_fig:
        fig.savefig('../plot/ece_temp.pdf', bbox_inches='tight')
    plt.show()



# Script entry point: this call runs at module level, so executing (or importing)
# this file renders the PCE figure, saves it, and then shows it.
plot_pce_all(save_fig=True)

# Uncomment to render and save the ECE grid as well:
# plot_ece_all(save_fig=True)











# Old code: earlier per-dataset plotting helpers, kept commented out for reference only.

# def plot_pce_subplot(subplot_axd, dataset):
#     pm = torch.linspace(0, 1, PCE_BIN + 1)[1:][None, ...]
#     probs = pm.squeeze()
#     for i, model_name in enumerate(MODEL_NAMES):
#         if dataset == 'ehrshot' and model_name == 'AttNHP':
#             continue
#         with open(f'../checkpoints/{dataset}/{model_name}/dt_cdf.pkl', 'rb') as f:
#             dt_cdf = pickle.load(f)
#
#         num_events = len(dt_cdf)
#         indicator_eval = (torch.tensor(dt_cdf)[..., None] <= pm).int().sum(dim=0) / num_events
#         subplot_axd.plot(probs, indicator_eval, label=model_name, color=cmap(i))
#     subplot_axd.plot(probs, probs, linestyle='--', color='tab:gray')
#     subplot_axd.legend()
#     # plt.xlabel('True CDF (of dts)')
#     # plt.ylabel('Estimated CDF (of dts)')
#     subplot_axd.set_title(DATASET_LABELS[dataset])


# def plot_ece_subplot(subplot_axd, dataset, model_name, model_id):
#     # compute ECE, weighted by number of predictions in each bin
#     prob_bins = np.linspace(0, 1, ECE_BIN + 1)
#
#     # for i, model_name in enumerate(MODEL_NAMES):
#     if dataset == 'ehrshot' and model_name == 'AttNHP':
#         return  # TODO: change this outside
#     with open(f'../checkpoints/{dataset}/{model_name}/mark_conf.pkl', 'rb') as f:
#         mark_conf = pickle.load(f)
#     with open(f'../checkpoints/{dataset}/{model_name}/mark_pred.pkl', 'rb') as f:
#         mark_pred = pickle.load(f)
#     with open(f'../checkpoints/{dataset}/{model_name}/true_mark.pkl', 'rb') as f:
#         true_mark = pickle.load(f)
#
#     num_events = len(true_mark)
#     true_mark_np = np.array(true_mark)
#     mark_pred_np = np.array(mark_pred)
#     mark_conf_np = np.array(mark_conf)
#     confidence = []
#     accuracy = []
#     frequency = []
#     for i in range(ECE_BIN):
#         # print(prob_bins[i])
#         mark_mask_i = (prob_bins[i] <= mark_conf_np) & (mark_conf_np < prob_bins[i + 1])
#         true_mark_i = true_mark_np[mark_mask_i]
#         mark_pred_i = mark_pred_np[mark_mask_i]
#         # mark_conf_i = mark_conf_np[mark_mask_i]
#
#         frequency.append(sum(mark_mask_i) / num_events)
#         acc = np.mean(true_mark_i == mark_pred_i) if sum(mark_mask_i) != 0 else 0  # TODO
#
#         confidence.append((prob_bins[i] + prob_bins[i + 1])/2)
#         accuracy.append(acc)
#
#         # if sum(mark_mask_i) != 0:
#         #     acc = np.mean(true_mark_i == mark_pred_i)
#         #     confidence.append((prob_bins[i] + prob_bins[i + 1]) / 2)
#         #     accuracy.append(acc)
#
#     # subplot_axd.plot(x=confidence, height=accuracy, label=model_name, color=cmap(model_id))
#     subplot_axd.set_ylim([0, 1])
#     subplot_axd.bar(x=confidence, height=accuracy, label=model_name, color=cmap(model_id), alpha=0.3, width=1/ECE_BIN)
#     # subplot_axd.hist(x=confidence, height=frequency, color=cmap(model_id))
#
#     ax_right = subplot_axd.twinx()
#     # sns.kdeplot(mark_conf, ax=ax_right, clip=[0,1])  # smaller bw_adjust -> less smoothing
#     sns.kdeplot(mark_conf, ax=ax_right, cut=0)
#     subplot_axd.plot(prob_bins, prob_bins, linestyle='--', color='tab:gray')
#     subplot_axd.legend()
#     subplot_axd.set_title(DATASET_LABELS[dataset])
#



# def plot_pce(save_fig=False):
#     fig, axd = plt.subplot_mosaic(
#         mosaic=[
#             ['Y1', "A", '.', "B"],
#             ['Y1', '.', '.', '.'],
#             ['Y2', "C", '.', "D"],
#             ['Y2', '.', '.', '.'],
#             ['Y3', "E", '.', "F"],
#             ['Y3', '.', '.', '.'],
#             ['Y4', "G", '.', "H"],
#             ['Y4', 'X1', '.', 'X2'],
#         ],
#         gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1],
#                      "height_ratios": [1, HEIGHT_GAP, 1, HEIGHT_GAP,
#                                        1, HEIGHT_GAP, 1, HEIGHT_CAPTION_GAP]},
#         figsize=(SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT),
#     )
#
#     disable_axis(axd["X1"]); axd["X1"].set_xlabel("True CDF", fontsize=LABEL_FONT)
#     disable_axis(axd["X2"]); axd["X2"].set_xlabel("True CDF", fontsize=LABEL_FONT)
#     disable_axis(axd["Y1"]); axd["Y1"].set_ylabel("Est. CDF", fontsize=LABEL_FONT, x=1, y=0.6)
#     disable_axis(axd["Y2"]); axd["Y2"].set_ylabel("Est. CDF", fontsize=LABEL_FONT, x=1, y=0.6)
#     disable_axis(axd["Y3"]); axd["Y3"].set_ylabel("Est. CDF", fontsize=LABEL_FONT, x=1, y=0.6)
#     disable_axis(axd["Y4"]); axd["Y4"].set_ylabel("Est. CDF", fontsize=LABEL_FONT, x=1, y=0.6)
#
#     # pm = torch.linspace(0, 1, PCE_BIN + 1)[1:][None, ...]
#     # probs = pm.squeeze()
#     plot_pce_subplot(axd["A"], 'amazon')
#     plot_pce_subplot(axd["B"], 'retweet')
#     plot_pce_subplot(axd["C"], 'taxi')
#     plot_pce_subplot(axd["D"], 'taobao')
#     plot_pce_subplot(axd["E"], 'stackoverflow')
#     plot_pce_subplot(axd["F"], 'ehrshot')
#     plot_pce_subplot(axd["G"], 'nlb1rep')
#     plot_pce_subplot(axd["H"], 'lastfm')
#
#     if save_fig:
#         plt.savefig('../plot/pce_temp.pdf', bbox_inches='tight')
#     plt.show()




# def plot_ece(model_name, model_id, save_fig=False):
#     fig, axd = plt.subplot_mosaic(
#         mosaic=[
#             ['Y1', "A", '.', "B"],
#             ['Y1', '.', '.', '.'],
#             ['Y2', "C", '.', "D"],
#             ['Y2', '.', '.', '.'],
#             ['Y3', "E", '.', "F"],
#             ['Y3', '.', '.', '.'],
#             ['Y4', "G", '.', "H"],
#             ['Y4', 'X1', '.', 'X2'],
#         ],
#         gridspec_kw={"width_ratios": [WIDTH_CAPTION_GAP, 1, WIDTH_GAP, 1],
#                      "height_ratios": [1, HEIGHT_GAP, 1, HEIGHT_GAP,
#                                        1, HEIGHT_GAP, 1, HEIGHT_CAPTION_GAP]},
#         figsize=(SUB_PLOT_WIDTH, SUB_PLOT_HEIGHT),
#     )
#
#     disable_axis(axd["X1"]); axd["X1"].set_xlabel("Confidence", fontsize=LABEL_FONT)
#     disable_axis(axd["X2"]); axd["X2"].set_xlabel("Confidence", fontsize=LABEL_FONT)
#     disable_axis(axd["Y1"]); axd["Y1"].set_ylabel("Accuracy", fontsize=LABEL_FONT, x=1, y=0.6)
#     disable_axis(axd["Y2"]); axd["Y2"].set_ylabel("Accuracy", fontsize=LABEL_FONT, x=1, y=0.6)
#     disable_axis(axd["Y3"]); axd["Y3"].set_ylabel("Accuracy", fontsize=LABEL_FONT, x=1, y=0.6)
#     disable_axis(axd["Y4"]); axd["Y4"].set_ylabel("Accuracy", fontsize=LABEL_FONT, x=1, y=0.6)
#
#     plot_ece_subplot(axd["A"], 'amazon', model_name, model_id)
#     plot_ece_subplot(axd["B"], 'retweet', model_name, model_id)
#     plot_ece_subplot(axd["C"], 'taxi', model_name, model_id)
#     plot_ece_subplot(axd["D"], 'taobao', model_name, model_id)
#     plot_ece_subplot(axd["E"], 'stackoverflow', model_name, model_id)
#     plot_ece_subplot(axd["F"], 'ehrshot', model_name, model_id)
#     plot_ece_subplot(axd["G"], 'nlb1rep', model_name, model_id)
#     plot_ece_subplot(axd["H"], 'lastfm', model_name, model_id)
#     if save_fig:
#         plt.savefig(f'./ece_temp_{model_name}.pdf', bbox_inches='tight')
#     plt.show()






# def get_est_cdf(dataset, model_name):
#     with open(f'../checkpoints/{dataset}/{model_name}/dt_cdf.pkl', 'rb') as f:
#         dt_cdf = pickle.load(f)
#
#     num_events = len(dt_cdf)
#     est_cdf = (torch.tensor(dt_cdf)[..., None] <= pm).int().sum(dim=0) / num_events
#     return est_cdf


#
# plt.figure(figsize=(6,6))
# pm = torch.linspace(0, 1, PCE_BIN + 1)[1:][None, ...]
# probs = pm.squeeze()
#
# for i, model_name in enumerate(MODEL_NAMES):
#     if dataset == 'ehrshot' and model_name == 'AttNHP':
#         continue
#     with open(f'../checkpoints/{dataset}/{model_name}/dt_cdf.pkl', 'rb') as f:
#         dt_cdf = pickle.load(f)
#
#     num_events = len(dt_cdf)
#     # compute PCE, Eq.61
#     indicator_eval = (torch.tensor(dt_cdf)[..., None] <= pm).int().sum(dim=0) / num_events
#     # pce = torch.mean(abs(indicator_eval - pm))
#     # print(f'PCE: {np.round(pce, 5)}')
#
#     plt.plot(probs, indicator_eval, label=model_name, color=cmap(i))
# plt.plot(probs, probs, linestyle='--', color='tab:gray')
# plt.legend()
# plt.xlabel('True CDF (of dts)')
# plt.ylabel('Estimated CDF (of dts)')
# plt.title(dataset)
# plt.show()