import os
import numpy as np
import pickle as pkl
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import functools
import operator
from color import cmap


# Number of distinct event types (marks) in each dataset. The loop below uses
# it to size the per-mark buckets and the bar-chart colour palette.
dict_num_marks = {
    'bookorder': 2,
    'retweet': 3,
    'stackoverflow': 22,
    'taobao': 17,
    'usearthquake': 7,
    'yelp': 3,
}

# Datasets to process, in this order. For a quick run on a subset, trim the
# list (e.g. to ['stackoverflow'] or ['yelp']).
datasets = [
    'retweet',
    'stackoverflow',
    'usearthquake',
    'yelp',
    'taobao',
    'bookorder',
]

def functools_reduce_iconcat(a):
    """Flatten ``a`` by exactly one level and return the result as a new list.

    Each item of ``a`` must itself be iterable; its elements are appended, in
    order, to a fresh list. ``a`` and its items are left unmodified.
    """
    flattened = []
    for chunk in a:
        # extend() rather than ``+=``: it always appends element-wise, even
        # when ``chunk`` is a type that overloads addition.
        flattened.extend(chunk)
    return flattened

# Figure styling is identical for every dataset, so configure it once. It has
# to be in place before the first figure is created for figsize to apply.
plt.rcParams.update({'font.size': 30, 'figure.figsize': (9, 9)})

# For each dataset:
#   1. write per-mark summary statistics (quartiles and mean) of mae_e to a
#      text report, and
#   2. save a bar chart showing the share of each mark among the values stored
#      under 'events_pred_index'.
for selected_dataset in datasets:
    num_marks = dict_num_marks[selected_dataset]

    mae_e_data_location = os.path.join(
        'mae_e_by_time_event', '1', f'test_mae_e_by_time_event_{selected_dataset}.pkl')
    # NOTE: pickle can execute arbitrary code while loading; only point this at
    # result files produced by a trusted run.
    with open(mae_e_data_location, 'rb') as f_data:
        data = pkl.load(f_data)

    # Three flattening passes are applied to the stored nested lists before
    # counting marks (presumably one per nesting level of the dump - TODO
    # confirm against the code that writes the pickle).
    events_pred_index = functools_reduce_iconcat(
        functools_reduce_iconcat(functools_reduce_iconcat(data['events_pred_index'])))

    # Bucket every mae_e value by the mark paired with it in 'event_next'.
    mae_e_split_by_marks = [[] for _ in range(num_marks)]
    for mae_e_per_seq, event_next_per_seq in zip(data['mae_e'], data['event_next']):
        for each_mae_e, each_next_event in zip(mae_e_per_seq, event_next_per_seq):
            mae_e_split_by_marks[each_next_event].append(each_mae_e)

    # Everything needed has been extracted; drop the reference to the raw dump.
    del data

    text = os.path.join('mae_e_by_time_event', f'{selected_dataset}_mae_e_by_marks.txt')
    with open(text, 'w') as f_text:
        f_text.write(f'Dataset {selected_dataset} has {num_marks} marks.\n')
        for mark_idx, sub_mae_e_list in enumerate(mae_e_split_by_marks):
            tmp_sub_mae_e_list = np.array(sub_mae_e_list)
            f_text.write(
                f'For mark {mark_idx}, we have {tmp_sub_mae_e_list.shape[0]} mae_e records.\n')
            if tmp_sub_mae_e_list.size == 0:
                # A mark with no records has no statistics; report NaN
                # explicitly instead of running NumPy on an empty array.
                values = [np.nan] * 4
            else:
                values = [*np.percentile(tmp_sub_mae_e_list, [25, 50, 75]),
                          tmp_sub_mae_e_list.mean()]
            f_text.write('Q1: {0}, Q2: {1}, Q3: {2}, mean: {3}. \n'.format(*values))

    # Count how often each mark occurs; marks that never occur keep a count of 0.
    dict_sampled_events = {
        'Event Types': [f'{idx}' for idx in range(num_marks)],
        'The number of samples': [0] * num_marks,
    }
    found_events, event_counts = np.unique(events_pred_index, return_counts=True)
    for found_event, event_count in zip(found_events, event_counts):
        dict_sampled_events['The number of samples'][found_event] = event_count
    df_all_sampled_events = pd.DataFrame.from_dict(dict_sampled_events)
    df_all_sampled_events['Percentage'] = (
        df_all_sampled_events['The number of samples']
        / df_all_sampled_events['The number of samples'].sum()
    )

    # set_palette() only changes global state and returns None, so the axes are
    # created explicitly (after the palette is set) and handed to seaborn.
    sns.set_palette(cmap, n_colors=num_marks)
    fig, ax = plt.subplots()
    sns.barplot(x='Event Types', y='Percentage', hue='Event Types',
                data=df_all_sampled_events, ax=ax)
    if selected_dataset in ['stackoverflow', 'taobao']:
        # These two datasets have the most marks, so use smaller x tick labels.
        ax.xaxis.set_tick_params(labelsize=18)
    fig.savefig(
        os.path.join('mae_e_by_time_event', f'time_event_mark_distribution_{selected_dataset}.png'),
        dpi=1000, bbox_inches="tight")
    # Release the figure; otherwise one figure per dataset stays alive in pyplot.
    plt.close(fig)
