import os
import numpy as np
import pickle as pkl

# Directory containing this script; the summary text files are written beneath it.
root_path = os.path.dirname(os.path.abspath(__file__))

# Datasets to summarise. For a quick single-dataset run, narrow this list,
# e.g. to ['stackoverflow'].
datasets = [
    'bookorder',
    'retweet',
    'stackoverflow',
    'taobao',
    'usearthquake',
    'yelp',
]
model_name = ['ifn']
# Names of the per-run sub-directories that hold the pickled results.
retries = [str(run) for run in (1, 2, 3)]


# Number of marks in each dataset.
dict_num_marks = dict(
    bookorder=2,
    retweet=3,
    stackoverflow=22,
    taobao=17,
    usearthquake=7,
    yelp=3,
)

# For every dataset/model pair, load the pickled test results of each retry and
# summarise, per mark, where the probability assigned to the mark that actually
# occurred ranks among all probabilities predicted for that mark. The summary
# is written to '<root_path>/mae_e_data/<model>/<dataset>_pm.txt'.
for selected_dataset in datasets:
    num_marks = dict_num_marks[selected_dataset]
    for selected_model in model_name:
        # results[mark] collects one list per retry. Each list holds, for every
        # event whose next mark is `mark`, the fraction of samples whose
        # probability for that mark is strictly below the event's probability.
        results = {idx: [] for idx in range(num_marks)}
        # Event counts per mark, recorded from the first retry only.
        dataset_size_for_each_mark = [0] * num_marks
        write_mark_size = False
        text = os.path.join(root_path, 'mae_e_data', selected_model, f'{selected_dataset}_pm.txt')

        # The context manager guarantees the report is flushed and closed even
        # if loading or processing one of the retries fails.
        with open(text, 'w') as f_text:
            f_text.write(f'Dataset {selected_dataset} has {num_marks} marks.\n')

            for retry_idx in retries:
                # NOTE(review): this path is relative to the current working
                # directory, whereas the output path above is anchored to
                # root_path — confirm the script is always run from root_path.
                mae_e_data_location = os.path.join('mae_e_data', selected_model, retry_idx, f'test_mae_e_data_{selected_dataset}.pkl')
                # SECURITY: unpickling can execute arbitrary code; only load
                # result files produced by trusted runs.
                with open(mae_e_data_location, 'rb') as f_data:
                    data = pkl.load(f_data)

                # pm_value_while_happens[mark]: probability assigned to `mark`
                # at each event where `mark` was the next event.
                pm_value_while_happens = [[] for _ in range(num_marks)]
                for event_next_per_seq, p_m_per_seq in zip(data['events_next'], data['pm']):
                    # Assumes each 'pm' entry carries a leading axis of length
                    # one in front of the per-event rows — TODO confirm.
                    p_m_per_seq = p_m_per_seq[0]
                    for each_next_event, pm in zip(event_next_per_seq, p_m_per_seq):
                        pm_value_while_happens[each_next_event].append(pm[each_next_event])

                # Stack all sequences into one 2-D (sample, mark) array.
                # NOTE(review): squeeze() would also collapse a sequence that
                # contains a single event, which would break the concatenation
                # below — verify such sequences cannot occur.
                pm_arrays = [np.array(sublist).squeeze() for sublist in data['pm']]
                reshaped_pm = np.concatenate(pm_arrays, axis=0)
                sample_size, _ = reshaped_pm.shape

                for per_mark in range(num_marks):
                    selected_probability_list = pm_value_while_happens[per_mark]
                    if not write_mark_size:
                        dataset_size_for_each_mark[per_mark] = len(selected_probability_list)
                    all_pm = reshaped_pm[:, per_mark]
                    # Fraction of all samples whose probability for this mark
                    # is strictly smaller than the value seen at the event.
                    percentage_of_one_mark = [
                        (all_pm < value).astype(int).sum() / sample_size
                        for value in selected_probability_list
                    ]
                    results[per_mark].append(percentage_of_one_mark)

                write_mark_size = True

            for mark_idx, sub_mae_e_list in results.items():
                # One row per retry; the statistics below are taken within each
                # retry (last axis) and then averaged across retries.
                tmp_sub_mae_e_list = np.array(sub_mae_e_list)
                quartiles = np.percentile(tmp_sub_mae_e_list, [25, 50, 75], axis=-1)
                per_retry_mean = np.expand_dims(np.mean(tmp_sub_mae_e_list, axis=-1), axis=0)
                # Rows: Q1, Q2, Q3, mean; columns: retries.
                q123_and_mean = np.concatenate((quartiles, per_retry_mean), axis=0)
                mean = q123_and_mean.mean(axis=-1)
                std = q123_and_mean.std(axis=-1)
                f_text.write(f'For mark {mark_idx}, we have {dataset_size_for_each_mark[mark_idx]} mae_e records.\n')
                f_text.write('Q1: {0}±{4}, Q2: {1}±{5}, Q3: {2}±{6}, mean: {3}±{7} \n'.format(*mean, *std))