#!/usr/bin/env python
# coding: utf-8

# In[11]:


import json, os, yaml, glob, io
import pandas as pd
import numpy as np


# In[12]:

# All locations are resolved relative to this script's own directory, so the
# current working directory does not matter: <script dir>/../data
root_path = os.path.split(os.path.abspath(__file__))[0]
root = os.path.join(root_path, '..', 'data')

# Window sizes. The conversion loop cuts windows of
# length_of_x + length_of_h events, and both values become part of the
# output directory name and of the written dataset card.
length_of_x, length_of_h = 20, 35

# Source datasets live in <root>/TPP; converted datasets go to <root>/ehd.
src_datasets_dir, tag_datasets_dir = (
    os.path.join(root, sub_dir) for sub_dir in ('TPP', 'ehd')
)

# Dataset selection. The alternatives below are kept for reference; the
# trailing numbers are presumably the window sizes tried for each dataset
# (NOTE(review): confirm their exact meaning with the author).
# datasets = os.listdir(src_datasets_dir)
# datasets = ['hawkes_1_v2', 'hawkes_2_v2', 'poisson_v2', 'self_correct_v2', 'stationary_renewal_v2']  # 24 40
# datasets = ['bookorder']                                                       # 5 [15, 30, 5]
# datasets = ['stackoverflow']                                                   # [15, 50], [15, 45], [15, 40], [20, 50], [25, 50]
# datasets = ['mooc']                                                            # 15 [30, 50, 5]
# datasets = ['retweet']                                                         # [10, 25], [10, 30], [10, 35], [15, 35], [20, 35]
datasets = ['yelp']                                                         # [10, 25], [10, 30], [10, 35], [15, 35], [20, 35]


# In[14]:


# Convert every dataset listed in `datasets` into fixed-length sliding windows.
#
# For each dataset directory under src_datasets_dir:
#   * dataset_card.yml is loaded, extended with length_of_x / length_of_h and
#     written to the target directory;
#   * every *.json file is loaded into a DataFrame, each row's sequences are
#     cut into windows of (length_of_x + length_of_h) consecutive entries, and
#     the windows are written to a JSON file of the same name in
#     <tag_datasets_dir>/<dataset>_<length_of_x>_<length_of_h>/.
#
# Output columns: 'time_seq' (inter-event gaps), 'event', 'score', 'intensity'
# and 'mask' (all ones, one entry per window position).
window_len = length_of_x + length_of_h
# Every emitted window is full-length, so its mask is always the same all-ones
# list; build it once and append a copy per window.
full_mask = np.ones(window_len).tolist()

for sample_dataset in datasets:
    sample_dataset_path = os.path.join(src_datasets_dir, sample_dataset)
    tag_dataset_path = os.path.join(tag_datasets_dir, '_'.join([sample_dataset, str(length_of_x), str(length_of_h)]))
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(tag_dataset_path, exist_ok=True)

    # load the property card (context manager closes the file even if parsing fails)
    with open(os.path.join(sample_dataset_path, 'dataset_card.yml'), 'r', encoding='utf8') as f_property:
        dataset_card = yaml.safe_load(f_property)

    # write new properties to the dataset_card
    dataset_card['length_of_x'] = length_of_x
    dataset_card['length_of_h'] = length_of_h

    sample_dataset_basename = [
        os.path.basename(item)
        for item in glob.glob(os.path.join(sample_dataset_path, '*.json'))
    ]

    for basename in sample_dataset_basename:
        print(f'Now processing {basename} of {sample_dataset}.')
        sample_dataset_filename = os.path.join(sample_dataset_path, basename)

        with open(sample_dataset_filename, 'r', encoding='utf8') as f_data:
            dataset_json = json.load(f_data)

        # Reference time used as the predecessor of the first event when
        # converting absolute times into gaps.
        start_time = dataset_card['t_0']

        df_dataset = pd.DataFrame.from_dict(dataset_json)

        ehd_time_seq = []
        ehd_event_seq = []
        ehd_intensity = []
        ehd_score = []
        ehd_mask = []

        for _, each_item in df_dataset.iterrows():
            time_seq = np.array(each_item['time_seq'], dtype=np.float32)
            event_seq = np.array(each_item['event'], dtype=np.int32)
            intensity = np.array(each_item['intensity'], dtype=np.float32)
            score = np.array(each_item['score'], dtype=np.float32)
            # Successive differences, with t_0 prepended so the first entry is
            # the gap between t_0 and the first timestamp.
            time_seq = np.diff(time_seq, prepend=start_time)

            # Validated once here; windows are slices of this array, so they
            # need no separate check.
            assert (event_seq >= 0).all(), 'negative in data!'

            seq_len = len(time_seq)

            # We will ignore the sequence that is too short.
            if seq_len < window_len:
                continue

            # NOTE(review): a sequence of length seq_len admits
            # seq_len - window_len + 1 full windows, but this emits
            # max(seq_len - window_len, 1), i.e. the last window is skipped
            # whenever seq_len > window_len. Kept as-is so previously
            # generated datasets stay reproducible; confirm whether intended.
            for start_idx in range(max(seq_len - window_len, 1)):
                stop_idx = start_idx + window_len

                ehd_time_seq.append(time_seq[start_idx:stop_idx].tolist())
                ehd_event_seq.append(event_seq[start_idx:stop_idx].tolist())
                ehd_intensity.append(intensity[start_idx:stop_idx].tolist())
                ehd_score.append(score[start_idx:stop_idx].tolist())
                ehd_mask.append(list(full_mask))

        ehd_data_dict = {
            'time_seq': ehd_time_seq,
            'event': ehd_event_seq,
            'score': ehd_score,
            'intensity': ehd_intensity,
            'mask': ehd_mask
        }
        ehd_df = pd.DataFrame.from_dict(ehd_data_dict)
        # Number of windows produced for this file.
        print(ehd_df.shape[0])
        ehd_df.to_json(os.path.join(tag_dataset_path, basename))

    # Store the extended card next to the converted data.
    with io.open(os.path.join(tag_dataset_path, 'dataset_card.yml'), 'w', encoding = 'utf8') as outfile:
        yaml.dump(dataset_card, outfile, default_flow_style=False, allow_unicode=True)
        


# In[ ]:




