import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
import seaborn as sns
import os
import pandas as pd
from einops import rearrange
from color import cmap

# Directory holding this script; the per-dataset input/output folders are
# resolved relative to it.
root_path = os.path.split(os.path.abspath(__file__))[0]

# ---- experiment selection ----
dataset = 'usearthquake'       # sub-folder name and file-name component
history_length_ratio = 0.4     # file-name component of the sampled-time pickles
number_of_attempts = 10        # pickles *_0.pkl .. *_{n-1}.pkl are read

# ---- histogram / figure appearance ----
number_of_bins = 50            # point count for np.linspace, i.e. number_of_bins - 1 bins
font_size = 30                 # matplotlib rcParams['font.size']
fig_size = (9, 9)              # matplotlib figsize (width, height) in inches

# ---- optional clamping of sampled times ----
# IFN can output sampled times of about 10^6 for events that are very unlikely
# to happen. Set clamp_sampled_time = True to clip every sampled time into
# [min_time, max_time]; those bounds then also become the histogram range.
clamp_sampled_time = False
min_time, max_time = 0, 14

# ---- miscellaneous ----
epsilon = 1e-20                # not referenced in the visible script
mark_dimension = 0             # placeholder; reassigned from each loaded array's shape

def _load_attempt(attempt_idx):
    """Load one attempt's pickle and return ``(sampled_time, p_m)``.

    ``sampled_time`` is ``info['sampled_times']`` squeezed and transposed with
    ``'a b -> b a'``, so its first axis is the one iterated per mark below.
    ``p_m`` is ``info['p_m']`` squeezed and averaged over its first axis,
    promoted to at least 1-D so it can always be zipped per mark.
    """
    path = os.path.join(root_path, dataset, f'sampled_time_{dataset}_{history_length_ratio}_{attempt_idx}.pkl')
    # NOTE: pickle.load can execute arbitrary code; only point this at
    # trusted, locally generated files.
    with open(path, 'rb') as f_data:  # closed even if unpickling fails
        info = pkl.load(f_data)
    sampled_time = info['sampled_times'].squeeze()
    if dataset == 'volcano':
        # squeeze() leaves a 1-D array here; restore the singleton axis that
        # becomes the (single) mark axis after the transpose below.
        sampled_time = rearrange(sampled_time, 'a -> a ()')
    sampled_time = rearrange(sampled_time, 'a b -> b a')
    # If p_m reduces to a 0-d scalar (presumably the single-mark case), zip()
    # over it would raise; atleast_1d is a no-op for the 1-D case.
    p_m = np.atleast_1d(info['p_m'].squeeze().mean(axis = 0))
    return sampled_time, p_m


# load data: every attempt is read exactly once and reused for both passes
attempts = [_load_attempt(attempt_idx) for attempt_idx in range(number_of_attempts)]

# Largest per-mark maximum over all attempts; the upper histogram edge when
# clamping is disabled.
max_ = 0
for sampled_time, _ in attempts:
    for row_max in sampled_time.max(axis = 1):
        if max_ < row_max:
            max_ = row_max

bins = np.linspace(min_time if clamp_sampled_time else 0, max_time if clamp_sampled_time else max_, number_of_bins)

# One long-form DataFrame (Time, Probability, Mark) per attempt.
all_data = []
for sampled_time, p_m in attempts:
    mark_dimension, the_number_of_samples = sampled_time.shape
    dfs = []

    for mark_idx, (ys, p) in enumerate(zip(sampled_time, p_m)):
        if clamp_sampled_time:
            ys = np.clip(ys, a_min = min_time, a_max = max_time)

        # density=True normalises the histogram to unit area; multiplying by p
        # presumably weights it by the (mean) probability of this mark.
        hist, bin_edges = np.histogram(ys, bins = bins, density = True)
        hist = hist * p
        hist = np.clip(hist, a_min = -10, a_max = 1e6)  # cap extreme densities

        dfs.append(pd.DataFrame.from_dict({
            'Time': bin_edges[:-1],  # left edge of each bin
            'Probability': hist,
            'Mark': [f'Mark {mark_idx}', ] * hist.shape[0]
        }))

    # ignore_index=True: each per-mark frame restarts at 0, so keep labels unique.
    all_data.append(pd.concat(dfs, ignore_index = True, axis = 0))

# Render all attempts on one figure. Rows from different attempts share the
# same (Time, Mark) values, and seaborn's lineplot aggregates such repeated
# x values (by default a mean line with a confidence band).
plt.rcParams.update({'font.size': font_size})
fig = plt.figure(figsize = fig_size)
# ignore_index=True gives a unique RangeIndex instead of labels repeated per attempt.
df_all_data = pd.concat(all_data, ignore_index = True, axis = 0)
# set_palette only changes the global palette and returns None, so it is not
# assigned to ax; the target axes are taken explicitly from the new figure.
sns.set_palette(cmap, n_colors = mark_dimension)
ax = sns.lineplot(data = df_all_data, x = 'Time', y = 'Probability', hue = 'Mark', ax = fig.gca(), linewidth = 2)
if dataset == 'stackoverflow':
    # this dataset has many marks: use a two-column, smaller-font legend
    ax.legend(loc='upper right', ncol = 2, title = "Mark", prop = {'size': 18})
plt.savefig(os.path.join(root_path, dataset, f'prediction_distribution_{dataset}_{history_length_ratio}_2d_all.png'),
            dpi = 1000, bbox_inches = "tight")