import os
import numpy as np
import pickle as pkl

# Absolute path of the directory containing this script; every input and
# output path below is built relative to it.
root_path = os.path.dirname(os.path.abspath(__file__))

# Datasets to process, in this order. Narrow the list (for example to
# ['usearthquake']) to analyse a single dataset.
datasets = [
    'bookorder',
    'taobao',
    'retweet',
    'stackoverflow',
    'usearthquake',
    'yelp',
]
# Names of the retry sub-directories whose results are aggregated.
retries = ['1', '2']

# Number of distinct event marks in each dataset.
dict_num_marks = dict(
    bookorder=2,
    retweet=3,
    stackoverflow=22,
    taobao=17,
    usearthquake=7,
    yelp=3,
)

# Marks treated as "frequent" for each dataset.
freq_marks = dict(
    bookorder=[0],
    retweet=[0, 1],
    stackoverflow=[0, 3, 4, 5, 8],
    taobao=[16],
    usearthquake=[0, 1, 2],
    yelp=[1],
)

# Marks treated as "rare" for each dataset. Together with freq_marks these
# lists cover every mark 0 .. dict_num_marks[name] - 1 exactly once.
rare_marks = dict(
    bookorder=[1],
    retweet=[2],
    stackoverflow=[1, 2, 6, 7, *range(9, 22)],
    taobao=list(range(16)),
    usearthquake=[3, 4, 5, 6],
    yelp=[0, 2],
)

# Per-event MAE values strictly above this are counted as "missing".
# NOTE(review): presumably a sentinel/penalty value written by the evaluation
# code for events that were not predicted — confirm against the producer.
MISSING_MAE_THRESHOLD = 5e5


def _load_pickle(path):
    """Load and return the object pickled at ``path``.

    The file handle is closed even if unpickling raises.
    NOTE(review): ``pickle.load`` can execute arbitrary code; only point this
    at result files produced by your own pipeline.
    """
    with open(path, 'rb') as f_data:
        return pkl.load(f_data)


def _split_mae_e_by_mark_group(data, frequent, rare):
    """Partition per-event MAE values by the mark of the next event.

    Args:
        data: mapping with ``'mae_e'`` and ``'event_next'`` entries; the two
            are iterated in lockstep, first per sequence, then per event.
        frequent: marks counted as frequent.
        rare: marks counted as rare.

    Returns:
        ``(mae_e_frequent, mae_e_rare)`` lists. Events whose mark is in
        neither group are dropped; a mark listed in both counts as frequent.
    """
    mae_e_frequent = []
    mae_e_rare = []
    for mae_e_per_seq, event_next_per_seq in zip(data['mae_e'], data['event_next']):
        for each_mae_e, each_next_event in zip(mae_e_per_seq, event_next_per_seq):
            # List membership (compares with ==) is kept on purpose: the
            # element type stored in the pickle is not visible here, and a
            # set lookup would additionally depend on how it hashes.
            if each_next_event in frequent:
                mae_e_frequent.append(each_mae_e)
            elif each_next_event in rare:
                mae_e_rare.append(each_mae_e)
    return mae_e_frequent, mae_e_rare


def _missing_fraction(values, threshold=MISSING_MAE_THRESHOLD):
    """Return the fraction of ``values`` strictly greater than ``threshold``.

    Returns None for an empty input so the caller can skip that group
    instead of dividing by zero.
    """
    if len(values) == 0:
        return None
    values = np.asarray(values)
    return np.count_nonzero(values > threshold) / values.shape[0]


def _format_percentage(fractions):
    """Format the mean ± population std (ddof=0) of ``fractions`` as percentages."""
    fractions = np.array(fractions)
    mean = fractions.mean(axis=0)
    std = fractions.std(axis=0)
    return 'Percentage: {0}%±{1}% \n'.format(mean * 100, std * 100)


# For every dataset: for each retry, compute the fraction of frequent-mark and
# rare-mark events whose MAE exceeds MISSING_MAE_THRESHOLD, then write the
# mean ± std over retries to
# <root_path>/mae_e_by_time_event/<dataset>_mae_e_freq_rare_percent.txt.
for selected_dataset in datasets:
    results_freq = []
    results_rare = []

    for retry_idx in retries:
        mae_e_data_location = os.path.join(
            root_path, 'mae_e_by_time_event', retry_idx,
            f'test_mae_e_by_time_event_{selected_dataset}.pkl')
        data = _load_pickle(mae_e_data_location)
        mae_e_frequent, mae_e_rare = _split_mae_e_by_mark_group(
            data, freq_marks[selected_dataset], rare_marks[selected_dataset])
        # Release the unpickled payload before loading the next retry.
        del data

        # A group with no events in this retry contributes no sample.
        missing_percentage_freq = _missing_fraction(mae_e_frequent)
        if missing_percentage_freq is not None:
            results_freq.append(missing_percentage_freq)
        missing_percentage_rare = _missing_fraction(mae_e_rare)
        if missing_percentage_rare is not None:
            results_rare.append(missing_percentage_rare)

    summary_path = os.path.join(root_path, 'mae_e_by_time_event',
                                f'{selected_dataset}_mae_e_freq_rare_percent.txt')
    # Explicit UTF-8: the report contains '±', which the platform-default
    # encoding is not guaranteed to support.
    with open(summary_path, 'w', encoding='utf-8') as f_text:
        f_text.write(f'Dataset {selected_dataset} has {dict_num_marks[selected_dataset]} marks.\n')
        if results_freq:
            f_text.write('For frequent marks.\n')
            f_text.write(_format_percentage(results_freq))
        if results_rare:
            f_text.write('For rare marks.\n')
            f_text.write(_format_percentage(results_rare))