import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import os
import pandas as pd

from color import cmap


# ---- parameters ----
# NOTE(review): "(12, 9)" looks like the figure size meant for the
# stackoverflow dataset (the plot below uses (9, 9)) -- confirm.
dataset = 'yelp'
event_type = 3  # number of event types shown on the x axis

# ---- load every split of the dataset into one DataFrame ----
# Splits are read from <dataset>/<dataset>/<split>.json.
# To inspect only the test split, set this to ['test.json'].
dataset_filenames = ['train.json', 'evaluate.json', 'test.json']
datasets = [
    pd.read_json(os.path.join(dataset, dataset, split_name))
    for split_name in dataset_filenames
]
full_dataset = pd.concat(datasets, ignore_index=True)
# ---- count how often each event type occurs across all sequences ----
# Each row's `event` cell is iterated as a sequence of integer event-type ids
# (presumably a JSON list -- TODO confirm against the dataset schema).
events_seqs = full_dataset.event.values

# Flatten every sequence into one 1-D array of event ids.
all_events = np.asarray(
    [event for events_seq in events_seqs for event in events_seq]
)
if all_events.size == 0:
    # Without this guard every percentage would be 0/0 = NaN and an empty
    # plot would be saved silently.
    raise ValueError(f'no events found in dataset {dataset!r}')

# Ids must lie in [0, event_type). A negative id would otherwise be
# miscounted (negative indexing) and a too-large id would crash obscurely.
out_of_range = (all_events < 0) | (all_events >= event_type)
if out_of_range.any():
    bad_ids = np.unique(all_events[out_of_range]).tolist()
    raise ValueError(
        f'event ids {bad_ids} fall outside [0, {event_type}); '
        f'check `event_type` for dataset {dataset!r}'
    )

# bincount with minlength gives a zero count for event types never observed.
event_counts = np.bincount(all_events, minlength=event_type)

dict_all_events = {
    'Event Types': [f'{idx}' for idx in range(event_type)],
    'The number of samples': event_counts,
}

df_all_events = pd.DataFrame.from_dict(dict_all_events)
# 'Percentage' holds fractions in [0, 1] that sum to 1, not values scaled by 100.
df_all_events['Percentage'] = (
    df_all_events['The number of samples']
    / df_all_events['The number of samples'].sum()
)
 
# ---- plot the per-event-type share and save it as a PNG ----
plt.rcParams.update({'font.size': 30, 'figure.figsize': (9, 9)})
# set_palette only updates seaborn/matplotlib global color settings and
# returns None. Call it before the Axes exists so its color cycle applies,
# and never use its result as the target Axes.
sns.set_palette(cmap, n_colors=event_type)
fig, ax = plt.subplots()
ax = sns.barplot(
    x='Event Types', y='Percentage', hue='Event Types',
    data=df_all_events, ax=ax,
)
if dataset in ['stackoverflow', 'taobao']:
    # Smaller x tick labels for these datasets (presumably so many event-type
    # labels fit along the axis -- confirm).
    ax.xaxis.set_tick_params(labelsize=18)
fig.savefig(
    os.path.join(dataset, f'{dataset}_mark_distribution.png'),
    dpi=1000, bbox_inches="tight",
)
# Release the figure once it has been written to disk.
plt.close(fig)