import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
import seaborn as sns
import os
import pandas as pd
from einops import rearrange
from color import cmap


# ---------------------------------------------------------------------------
# Plot configuration
# ---------------------------------------------------------------------------
dataset: str = 'taobao'              # sub-directory read from / written to; also part of every file name
history_length_ratio: float = 0.4    # tag embedded in the input and output file names
number_of_attempts: int = 10         # attempts 0 .. number_of_attempts - 1 are plotted
number_of_bins: int = 50             # histogram bins per mark
font_size: int = 30                  # global matplotlib font size
fig_size: tuple = (9, 9)             # figure size passed to plt.figure (inches)

# ---------------------------------------------------------------------------
# Optional clamping of sampled timestamps
# ---------------------------------------------------------------------------
# When enabled, every sampled time is clipped to [min_time, max_time] before
# it is histogrammed. This guards against nonsensical samples: according to
# the original author, IFN emits values around 10^6 for events that are very
# unlikely to happen.
clamp_sampled_time: bool = False
min_time: int = 0
max_time: int = 14

# ---------------------------------------------------------------------------
# Numerical constants
# ---------------------------------------------------------------------------
# NOTE(review): not referenced by the plotting loop as shown; kept for callers
# elsewhere that may rely on it.
epsilon: float = 1e-20

# Global text size for every figure produced below; set once, not per attempt.
plt.rcParams.update({'font.size': font_size})

for attempt_idx in range(number_of_attempts):
    # Load the sampled event times and mark probabilities saved for this attempt.
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # result files you produced yourself.
    data_path = os.path.join(dataset, f'sampled_time_{dataset}_{history_length_ratio}_{attempt_idx}.pkl')
    with open(data_path, 'rb') as f_data:
        info = pkl.load(f_data)

    sampled_time = info['sampled_times'].squeeze()
    # Average the per-sample mark probabilities into one weight per mark.
    # atleast_1d keeps the result iterable when squeeze() has collapsed a
    # single-mark dataset so that the mean is a scalar.
    # NOTE(review): assumes p_m is laid out as (samples, marks) — confirm.
    p_m = np.atleast_1d(info['p_m'].squeeze().mean(axis = 0))

    # squeeze() drops the mark axis for single-mark datasets (e.g. 'volcano');
    # restore it so the layout is always (samples, marks) before transposing.
    if sampled_time.ndim == 1:
        sampled_time = rearrange(sampled_time, 'a -> a ()')
    sampled_time = rearrange(sampled_time, 'a b -> b a')
    mark_dimension, the_number_of_samples = sampled_time.shape

    # Build one long-form frame per mark: the density histogram of that mark's
    # sampled times, scaled by the mark's averaged probability.
    dfs = []
    for mark_idx, (ys, p) in enumerate(zip(sampled_time, p_m)):
        if clamp_sampled_time:
            ys = np.clip(ys, a_min = min_time, a_max = max_time)

        hist, bin_edges = np.histogram(ys, bins = number_of_bins, density = True)
        hist = hist * p
        # Cap extreme spikes at 1e6; the -10 floor only matters if p can be negative.
        hist = np.clip(hist, a_min = -10, a_max = 1e6)

        dfs.append(pd.DataFrame({
            'Time': bin_edges[:-1],  # left edge of each bin
            'Probability': hist,
            'Mark': [f'Mark {mark_idx}'] * hist.shape[0],
        }))

    # Each per-mark frame is indexed from 0, so renumber the combined frame
    # instead of carrying duplicate index labels.
    df = pd.concat(dfs, ignore_index = True, axis = 0)

    # set_palette changes the global color cycle and returns None, so the axes
    # must be taken from the figure rather than from its return value.
    sns.set_palette(cmap, n_colors = mark_dimension)
    fig = plt.figure(figsize=fig_size)
    ax = fig.add_subplot()
    ax = sns.lineplot(data = df, x = 'Time', y = 'Probability', hue = 'Mark', ax = ax, linewidth = 4)
    if dataset in ['stackoverflow', 'taobao']:
        ax.legend(loc='upper right', ncol = 2, title = "Mark", prop = {'size': 18})
    fig.savefig(os.path.join(dataset, f'prediction_distribution_{dataset}_{history_length_ratio}_{attempt_idx}_2d.png'),
                dpi = 1000, bbox_inches = "tight")
    # Release the figure; otherwise every attempt's high-dpi figure stays open.
    plt.close(fig)