import matplotlib.pyplot as plt
import os
import numpy as np
import pickle as pkl
import seaborn as sns
import pandas as pd
from einops import rearrange

from color import cmap


# ---- parameters ----
# Directory holding the pickled results; it also appears in the input and
# output file names built below.
dataset = 'stackoverflow'
# Appears in the input/output file names (presumably the fraction of the
# sequence given as history — confirm against the script that wrote the pickles).
history_length_ratio = 0.4
# Number of pickled attempts to plot (attempt indices 0 .. attempts-1).
attempts = 10

# Optionally clamp sampled timestamps into [min_time, max_time] before plotting.
# IFN can emit times around 10^6 for events that are very unlikely to happen;
# left unclamped, such values would stretch the plot's time axis.
clamp_sampled_time = True
min_time, max_time = 0, 14

# For every attempt, load the pickled sampled times and save a box plot of the
# sampled-time distribution of each event type (mark).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only run this on result files you produced yourself.

# Figure styling is identical for every attempt, so apply it once.
plt.rcParams.update({'font.size': 30, 'figure.figsize': (9, 9)})

for attempt_idx in range(attempts):
    print(attempt_idx)

    data_path = os.path.join(dataset, f'which_event_first_{dataset}_{history_length_ratio}_{attempt_idx}.pkl')
    # `with` guarantees the file is closed even if unpickling raises.
    with open(data_path, 'rb') as f_data:
        info = pkl.load(f_data)

    # The pickle also carries 'event_history' and 'time_history'; this plot
    # only needs the sampled times.
    sampled_time = info['sampled_time'].squeeze()
    if dataset == 'volcano':
        # Here the squeezed array is 1-D; add a trailing axis of size 1 so it
        # is treated as a single mark.
        sampled_time = rearrange(sampled_time, 'a -> a ()')
    # Transpose so that each row holds all samples of one mark.
    sampled_time = rearrange(sampled_time, 'a b -> b a')
    mark_dimension, the_number_of_samples = sampled_time.shape

    if clamp_sampled_time:
        # Clipping is element-wise, so clipping the whole array equals
        # clipping each mark's row separately.
        sampled_time = np.clip(sampled_time, a_min=min_time, a_max=max_time)

    # Long-format table: one row per (mark, sample), columns 'Time' and 'Event'.
    df_sampled_event = pd.concat([
        pd.DataFrame({'Time': times_for_mark,
                      'Event': [f'Event {mark_idx}'] * the_number_of_samples})
        for mark_idx, times_for_mark in enumerate(sampled_time)
    ])

    # sns.set_palette returns None and only changes the global palette, so it
    # is called before the axes are created rather than used as an Axes.
    sns.set_palette(cmap, n_colors=mark_dimension)
    fig, ax = plt.subplots()
    sns.boxplot(x='Time', y='Event', hue='Event', data=df_sampled_event, ax=ax)
    # if dataset in ['stackoverflow', 'taobao']:
    #     ax.legend(loc='upper right', ncol = 2, title = "Mark", prop = {'size': 18})
    fig.savefig(os.path.join(dataset, f'which_event_will_happen_{dataset}_{history_length_ratio}_{attempt_idx}.png'),
                dpi=1000, bbox_inches="tight")
    # Release the figure; otherwise every attempt's figure stays open in pyplot.
    plt.close(fig)