import os
import numpy as np
import pickle as pkl
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

# Directory containing this script; the data files read and the figures /
# summaries written by the analysis loop are all resolved relative to it.
root_path = os.path.split(os.path.abspath(__file__))[0]

# Datasets to analyse. Every dataset with an entry in the tables below
# (bookorder, retweet, stackoverflow, taobao, usearthquake, yelp) is supported.
datasets = ['stackoverflow']
model_name = ['ifn']
# Run sub-directory names; kept as strings because they are used as path parts.
retries = [str(run) for run in range(1, 4)]


# Number of marks per dataset.
dict_num_marks = dict(
    bookorder=2,
    retweet=3,
    stackoverflow=22,
    taobao=17,
    usearthquake=7,
    yelp=3,
)

# Mark indices treated as "frequent" for each dataset.
freq_marks = {
    'bookorder': [0],
    'retweet': [0, 1],
    'stackoverflow': [0, 3, 4, 5, 8],
    'taobao': [16],
    'usearthquake': [0, 1, 2],
    'yelp': [1],
}

# Mark indices treated as "rare": every mark of the dataset that is not
# frequent, in ascending order. Deriving it this way guarantees that the
# frequent and rare groups always partition range(dict_num_marks[name]).
rare_marks = {
    name: [mark for mark in range(count) if mark not in freq_marks[name]]
    for name, count in dict_num_marks.items()
}

def _split_selected_pm(events_next, pm_sequences, freq_idx):
    """Collect the pm value assigned to each observed next mark.

    For every sequence, the observed next marks are zipped with the rows of
    ``pm_seq[0]`` (zip stops at the shorter of the two). For each pair, the
    entry of the row at the observed mark's index is taken. The value goes to
    the frequent list when the mark is in ``freq_idx`` and to the rare list
    otherwise.

    Returns:
        ``(freq_values, rare_values)`` as two Python lists.
    """
    freq_values, rare_values = [], []
    for events_seq, pm_seq in zip(events_next, pm_sequences):
        # NOTE(review): assumes each pm entry is (1, seq_len, num_marks)-like,
        # so that [0] yields one row per event -- confirm against the producer.
        for event, pm_row in zip(events_seq, pm_seq[0]):
            target = freq_values if event in freq_idx else rare_values
            target.append(pm_row[event])
    return freq_values, rare_values


def _flatten_pm(pm_sequences):
    """Stack the pm rows of every sequence into one 2-D array.

    Each entry is reshaped to ``(-1, last_dim)`` rather than squeezed, so a
    sequence of length one still contributes a 2-D block and does not break
    the concatenation.
    """
    blocks = []
    for seq in pm_sequences:
        arr = np.asarray(seq)
        blocks.append(arr.reshape(-1, arr.shape[-1]))
    return np.concatenate(blocks, axis=0)


def _rank_percentages(selected_values, pool):
    """For each selected value, return the fraction of ``pool`` strictly below it.

    This equals ``(value > pool).sum() / pool.size`` per value, computed with
    one sort plus a binary search instead of a full scan per value.
    """
    sorted_pool = np.sort(pool)
    # side='left' counts the elements strictly less than each query value.
    below = np.searchsorted(sorted_pool, np.asarray(selected_values), side='left')
    return below / sorted_pool.size


def _summarize(percentages):
    """Return ``[Q1, Q2, Q3, mean]`` of a non-empty sequence of percentages."""
    q1, q2, q3 = np.percentile(percentages, [25, 50, 75])
    return [q1, q2, q3, np.mean(percentages)]


def _plot_pm_distribution(selected_values, pool, out_path):
    """Overlay KDEs of the selected pm values and the whole pool; save to ``out_path``."""
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    fig = plt.figure(figsize=(9, 9))
    ax = sns.kdeplot(data=selected_values, label='selected_pm',
                     clip=(0, 1), common_norm=False, common_grid=True)
    sns.kdeplot(data=pool, ax=ax, label='all_pm',
                clip=(0, 1), common_norm=False, common_grid=True)
    ax.legend(loc='lower right', ncol=1, title="Mark", prop={'size': 18})
    plt.savefig(out_path, dpi=1000, bbox_inches="tight")
    plt.close(fig=fig)


def _write_summary(f_text, label, results):
    """Write mean ± std (across retries) of Q1, Q2, Q3 and mean; skip if no results."""
    if not results:
        return
    mean = np.mean(results, axis=0)
    std = np.std(results, axis=0)
    f_text.write(f'{label}: ')
    f_text.write('Q1: {0}±{4}, Q2: {1}±{5}, Q3: {2}±{6}, mean: {3}±{7} \n'.format(*mean, *std))


# For every dataset/model, and for each retry, compare the pm values of the
# marks that actually occurred against all pm values of the same mark group
# (frequent or rare). The results are a KDE figure and a rank-percentage
# summary per retry, aggregated into one text report per dataset/model.
plt.rcParams.update({'font.size': 30})

for selected_dataset in datasets:
    freq_idx = freq_marks[selected_dataset]
    rare_idx = rare_marks[selected_dataset]
    for selected_model in model_name:
        model_dir = os.path.join(root_path, 'mae_e_data', selected_model)
        results_freq, results_rare = [], []

        for retry_idx in retries:
            data_path = os.path.join(model_dir, retry_idx, f'test_mae_e_data_{selected_dataset}.pkl')
            # NOTE: pickle.load can execute arbitrary code; only point this at
            # files produced by this project's own pipeline.
            with open(data_path, 'rb') as f_data:
                data = pkl.load(f_data)

            selected_freq, selected_rare = _split_selected_pm(data['events_next'], data['pm'], freq_idx)
            all_pm = _flatten_pm(data['pm'])
            pool_freq = all_pm[:, freq_idx].flatten()
            pool_rare = all_pm[:, rare_idx].flatten()

            for group, selected, pool, results in (
                ('freq', selected_freq, pool_freq, results_freq),
                ('rare', selected_rare, pool_rare, results_rare),
            ):
                if pool.size == 0:
                    continue
                _plot_pm_distribution(
                    selected, pool,
                    os.path.join(root_path, selected_dataset, f'pm_{group}_distribution_{retry_idx}.png'),
                )
                # np.percentile raises on empty input, so skip retries in
                # which no event of this group occurred.
                if selected:
                    results.append(_summarize(_rank_percentages(selected, pool)))

        # Write the report only once all retries succeeded. Use explicit UTF-8
        # because the report contains '±'.
        report_path = os.path.join(model_dir, f'{selected_dataset}_pm_freq.txt')
        with open(report_path, 'w', encoding='utf-8') as f_text:
            f_text.write(f'Dataset {selected_dataset} has {dict_num_marks[selected_dataset]} marks.\n')
            _write_summary(f_text, 'Frequent events', results_freq)
            _write_summary(f_text, 'Rare events', results_rare)