import matplotlib.pyplot as plt
import numpy as np
import pickle as pkl
import seaborn as sns
import os
from einops import rearrange

# ---- run configuration ----
# `dataset` names both the directory that is read from / written to and the
# stem used in the input and output file names.
dataset = 'usearthquake'
history_length_ratio = 0.4   # appears in the input/output file names
number_of_attempts = 10      # attempts are indexed 0 .. number_of_attempts - 1

# ---- histogram settings ----
number_of_bins = 25          # bin count for each per-mark histogram
log_scale = False            # True: plot log(hist + epsilon) instead of hist

# ---- optional clamping of sampled timestamps ----
# Enable this to clamp the sampled timestamps when they come out as nonsense:
# IFN will generate values around 10^6 for events that are very unlikely to
# happen.
clamp_sampled_time = False
min_time, max_time = 0, 14   # clamp range, used only when clamping is enabled

# ---- numerical constants ----
epsilon = 1e-20              # added before taking the log so log(0) is avoided

# For every attempt: load the sampled timestamps, build one weighted histogram
# per mark, draw the histograms as bar charts stacked along the y-axis of a 3D
# plot, and save the figure as a PNG next to the input file.
for attempt_idx in range(number_of_attempts):
    # Load data. The context manager closes the file even if unpickling fails.
    # NOTE(review): pickle can execute arbitrary code while loading; only run
    # this script on files you generated yourself.
    data_path = os.path.join(dataset, f'sampled_time_{dataset}_{history_length_ratio}_{attempt_idx}.pkl')
    with open(data_path, 'rb') as f_data:
        info = pkl.load(f_data)
    sampled_time = info['sampled_times'].squeeze()
    p_m = info['p_m'].squeeze().mean(axis = 0)

    # For 'volcano' the squeezed samples are expected to be 1-D, so add a
    # trailing axis of size 1 before the transpose below (presumably because
    # that dataset has a single mark -- TODO confirm).
    if dataset == 'volcano':
        sampled_time = rearrange(sampled_time, 'a -> a ()')
    # Swap the two axes so that each row holds the samples of one mark.
    sampled_time = rearrange(sampled_time, 'a b -> b a')
    mark_dimension, the_number_of_samples = sampled_time.shape

    # plt.rcParams.update({'font.size': 12})
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(projection='3d')
    colors = sns.husl_palette(mark_dimension)
    yticks = np.arange(mark_dimension)
    for c, k, ys, p in zip(colors, yticks, sampled_time, p_m):
        if clamp_sampled_time:
            ys = np.clip(ys, a_min = min_time, a_max = max_time)

        # Density histogram of this mark's sampled times, scaled by p.
        hist, bin_edges = np.histogram(ys, bins = number_of_bins, density = True)
        hist = hist * p
        bin_widths = np.diff(bin_edges)
        width = bin_widths.mean() * 0.8
        # bar() centre-aligns bars on the x positions it is given, so use the
        # bin centres; passing the left edges would shift every bar half a bin
        # to the left of the interval it represents.
        bin_centers = bin_edges[:-1] + bin_widths / 2
        if log_scale:
            hist = np.log(hist + epsilon)
        # Bound the bar heights; in log scale this floors the value produced
        # by empty bins (log(epsilon)) at -10.
        hist = np.clip(hist, a_min = -10, a_max = 1e6)
        # Plot the bars on the plane y = k with 80% opacity.
        ax.bar(bin_centers, hist, zs = k, zdir = 'y', width = width, color = [c], alpha = 0.8, label = f'Mark {k}')

    ax.set_xlabel('Time')
    ax.set_ylabel('Mark')
    ax.set_zlabel('log-probability' if log_scale else 'probability')
    ax.legend(loc="upper right")

    # On the y-axis let's only label the discrete values that we have data for.
    ax.set_yticks(yticks)
    fig.savefig(os.path.join(dataset, f'prediction_distribution_{dataset}_{history_length_ratio}_{attempt_idx}.png'), dpi = 1000)
    # Release the figure: pyplot keeps every figure alive until it is closed,
    # so without this one large figure per attempt would accumulate in memory.
    plt.close(fig)