import os
import lzma
import seaborn as sns
import numpy as np
import pandas as pd
import pickle as pkl
from einops import rearrange
import matplotlib.pyplot as plt
from color import cmap


# --- Plot appearance ----------------------------------------------------------
fig_size: tuple = (9, 9)      # matplotlib figure size in inches (width, height)
font_size: int = 30           # global matplotlib font size
# Number of evenly spaced points in the time probe; these are used as histogram
# edges, so each histogram has number_of_bins - 1 bins.
number_of_bins: int = 200

# --- Time range -----------------------------------------------------------------
# When True, sampled timestamps are clipped into [min_time, max_time] and the
# time probe spans that range. This guards against degenerate samples: IFN
# produces timestamps around 1e6 for events that are very unlikely to occur.
clamp_sampled_time: bool = True
min_time: int = 0
# Upper clamp bound. Alternative values noted for specific datasets:
#   stackoverflow: 35
#   usearthquake: 35
max_time: int = 100

root_path = os.path.dirname(os.path.abspath(__file__))
samples_file = 'test_samples_for_every_point.pkl.lzma'
# Full dataset list: 'stackoverflow', 'yelp', 'retweet', 'usearthquake'.
datasets = ['stackoverflow']

selected_dataset = datasets[0]

# Load the pickled sampling results for the selected dataset from
# <root_path>/<selected_dataset>/<samples_file>. The `with` block guarantees the
# LZMA stream is closed even if unpickling fails.
# NOTE: pickle can execute arbitrary code while loading; only open dumps you
# produced yourself.
with lzma.open(os.path.join(root_path, selected_dataset, samples_file), 'rb') as f_data:
    data = pkl.load(f_data)

samples = data['samples']      # samples[i]: [num_samples, batch_size, seq_len, num_mark]
p_ms = data['p_ms']            # p_ms[i]: [batch_size, seq_len, num_mark]
# Materialize the (samples, p_ms) pairs: a bare zip is a one-shot iterator and
# would silently yield nothing if iterated a second time.
packed_data = list(zip(samples, p_ms))

# Largest sampled timestamp over all batches; it only matters as the upper bound
# of the time probe when clamping is disabled. Starting at 0 keeps the probe
# range non-negative. Taking the max over the whole array avoids the previous
# squeeze-then-iterate approach, which crashed on 0-d or empty arrays. NaNs are
# ignored.
max_ = 0
for sampled_time in samples:
    sampled_time = np.asarray(sampled_time)
    if sampled_time.size:
        max_ = max(max_, np.nanmax(sampled_time))

# number_of_bins evenly spaced points, used below as histogram bin edges.
if clamp_sampled_time:
    time_probe = np.linspace(min_time, max_time, number_of_bins)
else:
    time_probe = np.linspace(0, max_, number_of_bins)
mark_dimension = 0
# One DataFrame per (batch, mark) pair, accumulated across *all* batches so the
# plot reflects the whole test set rather than only the last batch.
dfs = []
# Left edge of every histogram bin (the probe points minus the final one).
left_edges = time_probe[:-1]

for sample, p_m in packed_data:
    # sample: [num_samples, batch_size, seq_len, num_mark]
    # p_m:    [batch_size, seq_len, num_mark]
    sampled_time = rearrange(np.array(sample), 'ns b s nm -> nm b s ns')
    p_m = rearrange(np.array(p_m), 'b s nm -> nm b s')
    mark_dimension, batch_size, seq_len, the_number_of_samples = sampled_time.shape

    for mark_idx, (ys, p) in enumerate(zip(sampled_time, p_m)):
        if clamp_sampled_time:
            ys = np.clip(ys, a_min=min_time, a_max=max_time)
        # One row per event position (batch and sequence axes merged). Unlike
        # squeeze(), this also works for batch_size > 1, seq_len == 1 and
        # num_samples == 1.
        ys = ys.reshape(-1, the_number_of_samples)    # [b * s, ns]
        p = p.reshape(-1)                             # [b * s]
        if ys.shape[0] == 0:
            continue

        # hists[k, i] is the density of position k's samples in bin i, weighted
        # by p[k] (presumably the predicted probability of this mark at that
        # position; TODO confirm). Stacking along axis 0 keeps each position's
        # bins contiguous after ravel(), so they line up with the tiled edges
        # below. The previous axis=-1 stacking paired bins with the wrong times.
        hists = np.stack(
            [np.histogram(y, bins=time_probe, density=True)[0] for y in ys],
            axis=0,
        )
        hists = hists * p[:, np.newaxis]
        # Keep values within [-10, 1e6] as before, guarding against degenerate
        # extreme values.
        hists = np.clip(hists, a_min=-10, a_max=1e6).ravel()
        bin_edges = np.tile(left_edges, ys.shape[0])

        dfs.append(pd.DataFrame({
            'Time': bin_edges,
            'Probability': hists,
            'Mark': f'Mark {mark_idx}',   # scalar broadcast to every row
        }))

if not dfs:
    raise ValueError(f'No samples to plot for dataset {selected_dataset!r}.')
        
# Global font size must be set before the figure is created.
plt.rcParams.update({'font.size': font_size})
fig, ax = plt.subplots(figsize=fig_size)
# Every per-mark frame has its own 0-based index; a fresh unique index avoids
# duplicate-label reindexing problems inside seaborn.
df = pd.concat(dfs, ignore_index=True, axis=0)
# set_palette configures the global palette and returns None, so its result
# must not be used as an Axes.
sns.set_palette(cmap, n_colors=mark_dimension)
# Rows sharing a Time value (different positions/batches) are aggregated by
# seaborn's default estimator for each Mark.
ax = sns.lineplot(data=df, x='Time', y='Probability', hue='Mark',
                  ax=ax, linewidth=2)
if selected_dataset in ['retweet']:
    ax.xaxis.set_tick_params(labelsize=26)
if selected_dataset in ['stackoverflow', 'taobao']:
    ax.legend(loc='upper right', ncol=2, title="Mark", prop={'size': 14})
fig.savefig(os.path.join(root_path, selected_dataset, 'prediction_distribution_2d_all.png'),
            dpi=1000, bbox_inches="tight")
# Release the (large, high-dpi) figure once it has been written.
plt.close(fig)