#!/usr/bin/env python
# coding: utf-8

import os
import argparse
from multiprocessing import Pool

import time
import scipy.stats
from einops import pack, repeat
from tqdm import tqdm

import numpy as np
import pickle as pkl
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Number of event types (marks) for each supported dataset; looked up with
# the --dataset command-line value.
dict_mark_number = dict(retweet = 3, yelp = 3, stackoverflow = 22)

# Command-line interface.
# Every option is used to build the path of the result file, so each one is
# marked as required: forgetting a flag now produces an immediate argparse
# usage error instead of a later KeyError or a path containing "None".
parser = argparse.ArgumentParser()

parser.add_argument('--dataset', type = str, required = True,
                    help = 'Evaluate mark distribution on this dataset.')
parser.add_argument('--seq_x', type = int, required = True,
                    help = 'The sequence length of x.')
parser.add_argument('--seq_h', type = int, required = True,
                    help = 'The sequence length of H.')
parser.add_argument('--batch_size', type = int, required = True,
                    help = 'The batch size used to train the EHD-MTPP model.')
parser.add_argument('--model_name', type = str, required = True,
                    help = 'The specific name of the distiller.')

opt = parser.parse_args()

# Fixed parts of the experiment layout on disk.
root = '/home/stepinsilence/project/workflow_clean'
result_rootdir = 'results'
procedure_name = 'ehd'
backend = 'fenn'
result_pkl = 'test_data_fast.pkl'

# Values taken from the command line.
model_name = opt.model_name
seq_x, seq_h = opt.seq_x, opt.seq_h
batch_size = opt.batch_size
mark_number = dict_mark_number[opt.dataset]

# Result folders are named after the dataset plus both sequence lengths.
dataset_name = f'{opt.dataset}_{seq_x}_{seq_h}'

# Only the 'ehd_perplexity' distiller uses the '_no_reverse' configuration name.
no_reverse_tag = '_no_reverse' if model_name == 'ehd_perplexity' else ''
model_folder = (
    f'results_{model_name}_lr0.001_bs{batch_size}_nts100000_'
    f'{backend}_dl.yml_ehd_{backend}{no_reverse_tag}.yml'
)

# Random-removal baseline settings:
#   sample_rate              - random removal masks drawn per sequence,
#   sample_number_of_retries - number of independent sample_new() runs,
#   process                  - worker processes in the pool.
sample_rate = 1024
sample_number_of_retries = 64
process = 3

result_pkl_path = os.path.join(
    root, result_rootdir, procedure_name, dataset_name, model_folder, result_pkl
)

# Load the evaluation results. The context manager closes the file even when
# unpickling fails.
# NOTE: pickle can execute arbitrary code while loading; only point this
# script at result files produced by a trusted run.
with open(result_pkl_path, 'rb') as f:
    result = pkl.load(f)

'''
output_name = [
    'percentage_remained_events', 'L_sp', 'l_sp_random', 'L_rp',
    'l_rp_random', 'time_baseline_1_given_percentage_to_ehd', 
    'history_mask', 'time_history', 'time_future', 'events_history', 'events_future'
]
'''

# Convert every stored field to an array and drop its singleton axis 1.
# NOTE(review): presumably axis 1 is a size-1 batch axis of the evaluation
# loader; squeeze(axis=1) raises if that axis is not of length one.
history_mask = np.squeeze(np.array(result['history_mask'], dtype = np.int32), axis = 1)

time_history = np.squeeze(np.array(result['time_history']), axis = 1)
time_future = np.squeeze(np.array(result['time_future']), axis = 1)

events_history = np.squeeze(np.array(result['events_history']), axis = 1)
events_future = np.squeeze(np.array(result['events_future']), axis = 1)


# Count, for each event type, how many events the model removed.

# The second channel of the history mask is the per-event removal flag.
removed_event_mask = history_mask[..., 1]
# Keep the mark of removed events only; every other position is replaced by
# a sentinel (mark_number + 1) that can never equal a real mark id.
removed_events = np.ma.array(events_history, mask = 1 - removed_event_mask)
removed_events = removed_events.filled(fill_value = mark_number + 1)
# Remove the first dummy event.
removed_events = removed_events[:, 1:]

# Count each mark id explicitly. np.unique(..., return_counts=True) returns
# counts aligned with the *sorted values that happen to be present*, so
# indexing them by mark id silently shifts the counts whenever a mark is
# absent or an extra (padding) value sorts before the marks. Both
# distributions keep the (values, counts) layout of np.unique so they can
# still be indexed as distribution[1][mark].
mark_ids = np.arange(mark_number)
removed_events_distribution = (
    mark_ids,
    np.array([(removed_events == mark).sum() for mark in mark_ids]),
)
original_events_distribution = (
    mark_ids,
    np.array([(events_history[:, 1:] == mark).sum() for mark in mark_ids]),
)


percentage_model = []
for event_id in range(mark_number):
    removed_count = removed_events_distribution[1][event_id]
    original_count = original_events_distribution[1][event_id]
    # A mark that never occurs has no meaningful removal rate.
    removed_percentage = removed_count / original_count if original_count else float('nan')
    print(f'For event type {event_id}, {removed_percentage * 100}% of events have been removed.')
    percentage_model.append(removed_percentage)


# Compare against random removal to make sure the learned removal
# distribution is non-trivial.

data_size, seq_len = removed_event_mask.shape
# NOTE(review): the "- 1" presumably discounts the leading dummy event, which
# is dropped from the per-mark counts as well - confirm it is always flagged.
num_removed_events = np.sum(removed_event_mask, axis = -1) - 1


def sample_new(*args):
    """Count the marks hit by uniformly random event removal.

    For every sequence in ``events_history``, ``sample_rate`` random subsets
    of its non-dummy positions are drawn without replacement. Each subset has
    as many positions as ``num_removed_events`` lists for that sequence. The
    marks found at the drawn positions are counted per mark id.

    Args:
        *args: Ignored. Accepted only so the function can be used directly
            with ``Pool.map``.

    Returns:
        Integer array of length ``mark_number``; entry ``m`` is the number of
        randomly removed events of mark ``m``, summed over all sequences and
        all ``sample_rate`` draws.

    Raises:
        ValueError: If a sequence asks for a negative number of removals or
            for more removals than it has non-dummy positions.
    """
    # Create the generator inside the call (not at module level) so that each
    # call, including calls in forked workers, gets fresh OS entropy.
    rng = np.random.default_rng()
    num_positions = seq_len - 1
    # One row of candidate positions per draw; reused for every sequence.
    base_idx = repeat(np.arange(num_positions), 'l -> sr l', sr = sample_rate)

    removed_events_list = []
    for num_removed_event, events_history_seq in zip(num_removed_events, events_history):
        if not 0 <= num_removed_event <= num_positions:
            raise ValueError(
                f'Cannot remove {num_removed_event} events from a sequence '
                f'with {num_positions} removable positions.'
            )

        # permuted() shuffles every row independently and returns a copy, so
        # the first k entries of a row are k distinct uniformly drawn positions.
        selected_idx = rng.permuted(base_idx, axis = -1)[:, :num_removed_event]
        # Skip the leading dummy event, then gather the marks that were drawn.
        sampled_events = events_history_seq[1:][selected_idx]
        event_counts = [int((sampled_events == mark).sum()) for mark in range(mark_number)]

        removed_events_list.append(event_counts)

    sum_of_removed_events = np.array(removed_events_list).sum(axis = 0)
    return sum_of_removed_events

# Run the random-removal baseline several times in parallel.
# NOTE(review): the pool is created at import time without an
# `if __name__ == '__main__':` guard; this presumably relies on the 'fork'
# start method - confirm before running where 'spawn'/'forkserver' is used.
start_time = time.time()
with Pool(processes = process) as pool:
    result = pool.map(sample_new, [None] * sample_number_of_retries)
end_time = time.time()
print(f'Running Time: {end_time - start_time}s.')

# One row per retry; dividing by sample_rate turns the summed counts into the
# average number of removed events per single random draw.
result = (np.stack(result, axis = 0) / sample_rate).tolist()

# Fraction of each mark's events removed by random removal, per retry.
original_counts = original_events_distribution[1]
all_percentage_random = np.array([
    [retry_counts[event_id] / original_counts[event_id] for event_id in range(mark_number)]
    for retry_counts in result
])


# Two-sided p-value of the model's removal rate for each mark, under a normal
# distribution fitted to the random-removal rates of that mark.
p_vals = []
p_val_file = f'p_val_{dataset_name}_{model_name}.txt'

# The context manager flushes and closes the report even if a fit fails.
with open(p_val_file, 'w') as f:
    for mark in range(mark_number):
        # norm.fit returns (loc, scale): the mean and the standard deviation
        # (not the variance), which is exactly what cdf() expects as `scale`.
        mean, std = scipy.stats.norm.fit(all_percentage_random[:, mark])
        cdf_value = scipy.stats.norm.cdf(percentage_model[mark], loc = mean, scale = std)
        p_val = 2 * min(cdf_value, 1 - cdf_value)
        p_vals.append(p_val)
        print(f'The p-val for mark {mark} is {p_val}.')
        f.write(f'The p-val for mark {mark} is {p_val}.\n')

# Constant subtracted from every plotted percentage (currently none).
offset = 0

# Stack the model's percentages (first row) on top of one row per random retry.
all_percentage, _ = pack((np.array(percentage_model), all_percentage_random), '* s')

# Long-format table: one record per (row, mark) pair, in row-major order so
# the three columns line up with the flattened percentage array.
num_rows = 1 + sample_number_of_retries
dict_data = {
    'Mark': list(range(mark_number)) * num_rows,
    'Distilled Percentage': all_percentage.flatten() * 100 - offset,
    'Method': ['EHD-MTPP'] * mark_number + ['RD'] * (mark_number * sample_number_of_retries),
}
df_data = pd.DataFrame.from_dict(dict_data)

# Small LaTeX-rendered figure with a serif font.
latex_style = {
    'font.size': 14,
    'figure.figsize': (2, 2),
    'text.usetex': True,
    'font.family': 'serif',
    'font.serif': 'Times',
    'mathtext.fontset': 'dejavusans',
    'text.latex.preamble': r"\usepackage{amsmath}",
}
plt.rcParams.update(latex_style)

fig = plt.figure()
ax = sns.barplot(
    df_data, x = 'Mark', y = 'Distilled Percentage', hue = 'Method',
    palette = "Set2", errorbar = ("pi", 50), capsize = 0.1,
)
# stackoverflow has many more marks, so shrink the tick and legend text.
if opt.dataset == 'stackoverflow':
    print('reset the font size of ticks on the x axis.')
    plt.xticks(fontsize = 6)
    plt.legend(fontsize = 6)

figure_path = f'removal_percentage_{dataset_name}_{model_name}.pdf'
plt.savefig(figure_path, bbox_inches = "tight")

# Release the figure and everything it references.
fig.clf()
plt.close(fig = fig)
del ax
import gc
gc.collect()
