import os
import numpy as np
import pickle as pkl


# Directory containing this script; the input pickles and the output reports
# are both located relative to it.
root_path = os.path.dirname(os.path.abspath(__file__))

# Datasets to summarise, in processing order. Each name must be unique (a
# repeated name would only redo the same work and rewrite the same report)
# and must have an entry in dict_num_marks. Narrow the list, e.g. to
# ['usearthquake'], to process a single dataset.
datasets = ['stackoverflow', 'bookorder', 'taobao',
            'retweet', 'usearthquake', 'yelp']

# Number of marks in each dataset; sizes the per-mark buckets and the report.
dict_num_marks = {
    'bookorder': 2,
    'retweet': 3,
    'stackoverflow': 22,
    'taobao': 17,
    'usearthquake': 7,
    'yelp': 3
}

# Names of the sub-directories (one per repeated run) under
# 'mae_e_by_time_event' that hold the pickled results.
retries = ['1', '2', '3']


def _load_errors_by_mark(path, num_marks):
    """Load one pickled result file and group its mae_e values by mark.

    The pickle is expected to be a mapping with parallel 'mae_e' and
    'event_next' entries, each a sequence of per-sequence sequences; the
    values in 'event_next' are used as indices in ``range(num_marks)``.

    Returns a list of ``num_marks`` lists, where entry ``m`` holds every
    mae_e value whose paired 'event_next' value is ``m``.
    """
    # NOTE: unpickling can execute arbitrary code; only point this at result
    # files that come from a trusted source.
    with open(path, 'rb') as f_data:
        data = pkl.load(f_data)

    buckets = [[] for _ in range(num_marks)]
    for mae_e_per_seq, event_next_per_seq in zip(data['mae_e'], data['event_next']):
        for each_mae_e, each_next_event in zip(mae_e_per_seq, event_next_per_seq):
            buckets[each_next_event].append(each_mae_e)
    # `data` goes out of scope here, so the loaded pickle is released before
    # the next file is read.
    return buckets


def _summarise(errors):
    """Return [Q1, median, Q3, mean] of *errors*, or four NaNs if it is empty."""
    # Percentiles of an empty sample are undefined (NumPy warns or raises,
    # depending on the version -- TODO confirm for the pinned version), so a
    # mark with no records is reported as NaN explicitly.
    if len(errors) == 0:
        return [np.nan] * 4
    return [*np.percentile(errors, [25, 50, 75]), np.mean(errors)]


# For every dataset: summarise mae_e per mark in each run, then write the
# mean and standard deviation of those summaries across runs to a text report.
for selected_dataset in datasets:
    num_marks = dict_num_marks[selected_dataset]
    # results[mark] collects one [Q1, Q2, Q3, mean] row per run.
    results = {idx: [] for idx in range(num_marks)}
    dataset_size_for_each_mark = None

    for retry_idx in retries:
        mae_e_data_location = os.path.join(
            root_path, 'mae_e_by_time_event', retry_idx,
            f'test_mae_e_by_time_event_{selected_dataset}.pkl')
        mae_e_split_by_marks = _load_errors_by_mark(mae_e_data_location, num_marks)

        # Record counts are taken from the first run only.
        if dataset_size_for_each_mark is None:
            dataset_size_for_each_mark = [len(bucket) for bucket in mae_e_split_by_marks]

        for mark, each_mae_e in enumerate(mae_e_split_by_marks):
            results[mark].append(_summarise(each_mae_e))

    text = os.path.join(root_path, 'mae_e_by_time_event', f'{selected_dataset}_mae_e_by_marks.txt')
    # The report contains the non-ASCII character '±', so the encoding is set
    # explicitly instead of depending on the platform default; the context
    # manager guarantees the file is flushed and closed.
    with open(text, 'w', encoding='utf-8') as f_text:
        f_text.write(f'Dataset {selected_dataset} has {dict_num_marks[selected_dataset]} marks.\n')
        for mark_idx, sub_mae_e_list in results.items():
            # Rows are runs, columns are (Q1, Q2, Q3, mean); reduce over runs.
            tmp_sub_mae_e_list = np.array(sub_mae_e_list)
            mean = tmp_sub_mae_e_list.mean(axis=0)
            std = tmp_sub_mae_e_list.std(axis=0)
            f_text.write(f'For mark {mark_idx}, we have {dataset_size_for_each_mark[mark_idx]} mae_e records.\n')
            f_text.write('Q1: {0}±{4}, Q2: {1}±{5}, Q3: {2}±{6}, mean: {3}±{7}. \n'.format(*mean, *std))