import os
import numpy as np
import pickle as pkl

# Directory that contains this script; all data/report paths are resolved against it.
root_path = os.path.split(os.path.abspath(__file__))[0]

# Datasets, models and retry runs whose MAE-E results are aggregated.
datasets = [
    'bookorder',
    'retweet',
    'stackoverflow',
    'taobao',
    'usearthquake',
    'yelp',
]
model_name = ['fullynn', 'fenn', 'ifn', 'thp', 'sahp']
# Retry ids double as sub-directory names in the data paths, hence strings.
retries = [str(run) for run in range(1, 4)]

# Number of distinct marks (event types) in each dataset.
dict_num_marks = dict(
    bookorder=2,
    retweet=3,
    stackoverflow=22,
    taobao=17,
    usearthquake=7,
    yelp=3,
)

# Mark indices grouped as "frequent" and "rare" per dataset. For every dataset
# the two lists together cover each index in range(dict_num_marks[name]) once.
freq_marks = dict(
    bookorder=[0],
    retweet=[0, 1],
    stackoverflow=[0, 3, 4, 5, 8],
    taobao=[16],
    usearthquake=[0, 1, 2],
    yelp=[1],
)

rare_marks = dict(
    bookorder=[1],
    retweet=[2],
    stackoverflow=[1, 2, 6, 7, *range(9, 22)],
    taobao=[*range(16)],
    usearthquake=[3, 4, 5, 6],
    yelp=[0, 2],
)

# Output line shared by every summary: across-retry mean ± std of [Q1, Q2, Q3, mean].
_SUMMARY_TEMPLATE = 'Q1: {0}±{4}, Q2: {1}±{5}, Q3: {2}±{6}, mean: {3}±{7}. \n'


def _load_mae_e_data(path):
    """Unpickle and return the object stored at *path*.

    NOTE(review): ``pickle`` can execute arbitrary code while loading; only
    point this at result files you generated yourself.
    """
    with open(path, 'rb') as f_data:
        return pkl.load(f_data)


def _split_mae_e_by_marks(data, num_marks):
    """Group MAE-E values by the mark of the corresponding next event.

    *data* must provide ``'mae_e'`` and ``'events_next'``: parallel sequences
    of per-sequence sequences. Returns ``num_marks`` lists; list ``k`` holds
    every MAE-E value whose next event has mark ``k``.
    """
    buckets = [[] for _ in range(num_marks)]
    for mae_e_per_seq, event_next_per_seq in zip(data['mae_e'], data['events_next']):
        for each_mae_e, each_next_event in zip(mae_e_per_seq, event_next_per_seq):
            # NOTE(review): a negative mark (e.g. padding) would silently be
            # counted from the end of `buckets` — confirm the data has none.
            buckets[each_next_event].append(each_mae_e)
    return buckets


def _quartiles_and_mean(values):
    """Return ``[Q1, median, Q3, mean]`` of *values*.

    Empty input yields four NaNs instead of calling ``np.percentile`` on an
    empty array (which can raise), so a mark that never occurs as a next
    event is reported as ``nan`` rather than aborting the whole run.
    """
    arr = np.asarray(values)
    if arr.size == 0:
        return [np.nan] * 4
    return [*np.percentile(arr, [25, 50, 75]), arr.mean()]


def _format_summary(rows):
    """Format the column-wise mean ± std of ``[Q1, Q2, Q3, mean]`` rows."""
    rows = np.asarray(rows)
    return _SUMMARY_TEMPLATE.format(*rows.mean(axis=0), *rows.std(axis=0))


for selected_dataset in datasets:
    num_marks = dict_num_marks[selected_dataset]
    frequent = set(freq_marks[selected_dataset])
    rare = set(rare_marks[selected_dataset])

    for selected_model in model_name:
        model_dir = os.path.join(root_path, 'mae_e_data', selected_model)
        text = os.path.join(model_dir, f'{selected_dataset}_mae_e_by_marks.txt')
        text_freq_rare = os.path.join(model_dir, f'{selected_dataset}_mae_e_by_freqs.txt')

        # One [Q1, Q2, Q3, mean] row per retry, per mark and per frequency group.
        results = {idx: [] for idx in range(num_marks)}
        results_for_freq_mark = []
        results_for_rare_mark = []
        # Number of MAE-E records per mark, taken from the first retry only.
        dataset_size_for_each_mark = [0] * num_marks

        # `with` guarantees both reports are flushed and closed even if a
        # pickle is missing or malformed.
        with open(text, 'w') as f_text, open(text_freq_rare, 'w') as f_text_freq_rare:
            f_text.write(f'Dataset {selected_dataset} has {num_marks} marks.\n')

            for retry_pos, retry_idx in enumerate(retries):
                mae_e_data_location = os.path.join(
                    model_dir, retry_idx, f'test_mae_e_data_{selected_dataset}.pkl'
                )
                data = _load_mae_e_data(mae_e_data_location)

                # MAE-E values pooled over all frequent / all rare marks of this retry.
                freq_mae_e = []
                rare_mae_e = []
                for mark_idx, sub_mae_e_list in enumerate(_split_mae_e_by_marks(data, num_marks)):
                    if mark_idx in frequent:
                        freq_mae_e.extend(sub_mae_e_list)
                    if mark_idx in rare:
                        rare_mae_e.extend(sub_mae_e_list)
                    if retry_pos == 0:
                        dataset_size_for_each_mark[mark_idx] = len(sub_mae_e_list)
                    results[mark_idx].append(_quartiles_and_mean(sub_mae_e_list))

                f_text_freq_rare.write(f'{len(freq_mae_e)} events have frequent marks.\n')
                f_text_freq_rare.write(f'{len(rare_mae_e)} events have rare marks.\n')
                if freq_mae_e:
                    results_for_freq_mark.append(_quartiles_and_mean(freq_mae_e))
                if rare_mae_e:
                    results_for_rare_mark.append(_quartiles_and_mean(rare_mae_e))

            for mark_idx, per_retry_rows in results.items():
                f_text.write(f'For mark {mark_idx}, we have {dataset_size_for_each_mark[mark_idx]} mae_e records.\n')
                f_text.write(_format_summary(per_retry_rows))

            # Summaries are computed only when at least one retry contributed
            # values; averaging an empty list would just give NaN plus a warning.
            if results_for_freq_mark:
                f_text_freq_rare.write('For frequent marks:\n')
                f_text_freq_rare.write(_format_summary(results_for_freq_mark))
            else:
                # Reached when no test event carried a frequent mark in any retry.
                f_text_freq_rare.write('No frequent mark has been defined.\n')

            if results_for_rare_mark:
                f_text_freq_rare.write('For rare marks:\n')
                f_text_freq_rare.write(_format_summary(results_for_rare_mark))
            else:
                # Reached when no test event carried a rare mark in any retry.
                f_text_freq_rare.write('No rare mark has been defined.\n')